diff --git a/backend/src/baserow/contrib/database/fields/field_helpers.py b/backend/src/baserow/contrib/database/fields/field_helpers.py index 4e043b743..78ff79303 100644 --- a/backend/src/baserow/contrib/database/fields/field_helpers.py +++ b/backend/src/baserow/contrib/database/fields/field_helpers.py @@ -254,8 +254,20 @@ def construct_all_possible_field_kwargs( "name": "ai", "ai_generative_ai_type": "test_generative_ai", "ai_generative_ai_model": "test_1", + "ai_output_type": "text", "ai_prompt": "'Who are you?'", - } + }, + { + "name": "ai_choice", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + "ai_prompt": "'What are you?'", + "ai_output_type": "choice", + "select_options": [ + {"id": 5, "value": "Object", "color": "orange"}, + {"id": 6, "value": "Else", "color": "yellow"}, + ], + }, ], } # If you have added a new field please add an entry into the dict above with any diff --git a/backend/src/baserow/contrib/database/fields/registries.py b/backend/src/baserow/contrib/database/fields/registries.py index 3009275aa..f3c106cce 100644 --- a/backend/src/baserow/contrib/database/fields/registries.py +++ b/backend/src/baserow/contrib/database/fields/registries.py @@ -975,7 +975,7 @@ class FieldType( field_id = serialized_copy.pop("id") serialized_copy.pop("type") select_options = ( - serialized_copy.pop("select_options") + serialized_copy.pop("select_options", []) if self.can_have_select_options else [] ) diff --git a/backend/src/baserow/contrib/database/views/handler.py b/backend/src/baserow/contrib/database/views/handler.py index 13ae8bdef..0ff3c598d 100644 --- a/backend/src/baserow/contrib/database/views/handler.py +++ b/backend/src/baserow/contrib/database/views/handler.py @@ -1976,6 +1976,8 @@ class ViewHandler(metaclass=baserow_trace_methods(tracer)): :return: The created view sort instance. """ + field = field.specific + workspace = view.table.database.workspace CoreHandler().check_permissions( user, ReadFieldOperationType.type, workspace=workspace, context=field diff --git a/backend/src/baserow/contrib/database/ws/public/rows/signals.py b/backend/src/baserow/contrib/database/ws/public/rows/signals.py index 9024d08cd..7e440bc05 100644 --- a/backend/src/baserow/contrib/database/ws/public/rows/signals.py +++ b/backend/src/baserow/contrib/database/ws/public/rows/signals.py @@ -274,6 +274,7 @@ def public_rows_updated( table_id=PUBLIC_PLACEHOLDER_ENTITY_ID, serialized_rows_before_update=visible_fields_only_old_rows, serialized_rows=visible_fields_only_updated_rows, + updated_field_ids=list(updated_field_ids), metadata={}, ), slug=public_view.slug, diff --git a/backend/src/baserow/contrib/database/ws/rows/signals.py b/backend/src/baserow/contrib/database/ws/rows/signals.py index 359f3f94b..a847e7ba6 100644 --- a/backend/src/baserow/contrib/database/ws/rows/signals.py +++ b/backend/src/baserow/contrib/database/ws/rows/signals.py @@ -81,6 +81,9 @@ def rows_updated( serialized_rows=get_row_serializer_class( model, RowSerializer, is_response=True )(rows, many=True).data, + # Broadcast a list of updated fields so that the listener can take + # action even if the value didn't change. + updated_field_ids=list(updated_field_ids), metadata=row_metadata_registry.generate_and_merge_metadata_for_rows( user, table, [row.id for row in rows] ), @@ -213,6 +216,7 @@ class RealtimeRowMessages: serialized_rows_before_update: List[Dict[str, Any]], serialized_rows: List[Dict[str, Any]], metadata: Dict[int, Dict[str, Any]], + updated_field_ids: List[int], ) -> Dict[str, Any]: return { "type": "rows_updated", @@ -223,6 +227,7 @@ class RealtimeRowMessages: "rows_before_update": serialized_rows_before_update, "rows": serialized_rows, "metadata": metadata, + "updated_field_ids": updated_field_ids, } @staticmethod diff --git a/backend/src/baserow/test_utils/helpers.py b/backend/src/baserow/test_utils/helpers.py index 41bc956b9..8df3b9dcb 100644 --- a/backend/src/baserow/test_utils/helpers.py +++ b/backend/src/baserow/test_utils/helpers.py @@ -250,6 +250,9 @@ def setup_interesting_test_table( "phone_number": "+4412345678", "password": "test", "ai": "I'm an AI.", + "ai_choice": SelectOption.objects.get( + value="Object", field_id=name_to_field_id["ai_choice"] + ).id, } with freeze_time("2020-02-01 01:23"): diff --git a/backend/tests/baserow/contrib/database/api/rows/test_row_serializers.py b/backend/tests/baserow/contrib/database/api/rows/test_row_serializers.py index ecbbd2b0a..0e3aad106 100644 --- a/backend/tests/baserow/contrib/database/api/rows/test_row_serializers.py +++ b/backend/tests/baserow/contrib/database/api/rows/test_row_serializers.py @@ -378,6 +378,11 @@ def test_get_row_serializer_with_user_field_names(data_fixture): "autonumber": 2, "password": True, "ai": "I'm an AI.", + "ai_choice": { + "color": "orange", + "id": SelectOption.objects.get(value="Object").id, + "value": "Object", + }, } ) ) diff --git a/backend/tests/baserow/contrib/database/api/views/test_view_serializers.py b/backend/tests/baserow/contrib/database/api/views/test_view_serializers.py index 9a383be91..ecc0bebe6 100644 --- a/backend/tests/baserow/contrib/database/api/views/test_view_serializers.py +++ b/backend/tests/baserow/contrib/database/api/views/test_view_serializers.py @@ -84,6 +84,8 @@ def test_serialize_group_by_metadata_on_all_fields_in_interesting_table(data_fix actual_result_per_field_name = {} + ai_choice_select_options = Field.objects.get(name="ai_choice").select_options.all() + for field in fields_to_group_by: counts = handler.get_group_by_metadata_in_rows([field], rows, queryset) serialized = serialize_group_by_metadata(counts)[field.db_column] @@ -258,4 +260,12 @@ def test_serialize_group_by_metadata_on_all_fields_in_interesting_table(data_fix {"count": 1, "field_duration_dhms": 90066.0}, {"count": 1, "field_duration_dhms": None}, ], + "ai": [ + {"count": 1, "field_ai": "I'm an AI."}, + {"count": 1, "field_ai": None}, + ], + "ai_choice": [ + {"count": 1, "field_ai_choice": ai_choice_select_options[0].id}, + {"count": 1, "field_ai_choice": None}, + ], } diff --git a/backend/tests/baserow/contrib/database/field/test_link_row_field_type.py b/backend/tests/baserow/contrib/database/field/test_link_row_field_type.py index 4b243761a..8a9e31ba1 100644 --- a/backend/tests/baserow/contrib/database/field/test_link_row_field_type.py +++ b/backend/tests/baserow/contrib/database/field/test_link_row_field_type.py @@ -303,7 +303,9 @@ def test_link_row_field_type_with_text_values(data_fixture): for field_type in [ f for f in field_type_registry.get_all() - if f.can_get_unique_values and not f.read_only + # The AI field is not compatible because it some field kwargs are required and + # not passed in. + if f.can_get_unique_values and not f.read_only and f.type != "ai" ]: field_type_name = field_type.type field_name = f"Field {field_type_name}" diff --git a/backend/tests/baserow/contrib/database/import_export/test_export_handler.py b/backend/tests/baserow/contrib/database/import_export/test_export_handler.py index 23d7152ce..7104f712a 100755 --- a/backend/tests/baserow/contrib/database/import_export/test_export_handler.py +++ b/backend/tests/baserow/contrib/database/import_export/test_export_handler.py @@ -247,12 +247,12 @@ def test_can_export_every_interesting_different_field_to_csv( "phone_number,formula_text,formula_int,formula_bool,formula_decimal,formula_dateinterval," "formula_date,formula_singleselect,formula_email,formula_link_with_label," "formula_link_url_only,formula_multipleselect,count,rollup,duration_rollup_sum," - "duration_rollup_avg,lookup,uuid,autonumber,password,ai\r\n" + "duration_rollup_avg,lookup,uuid,autonumber,password,ai,ai_choice\r\n" "1,,,,,,,,,0,False,,,,,,,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021," "02/01/2021 13:00,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,02/01/2021 13:00," "user@example.com,user@example.com,,,,,,,,,,,,,,,,,,,test FORMULA,1,True,33.3333333333," "1d 0:00,2020-01-01,,,label (https://google.com),https://google.com,,0,0.000," - "0:00,0:00,,00000000-0000-4000-8000-000000000001,1,,\r\n" + "0:00,0:00,,00000000-0000-4000-8000-000000000001,1,,,\r\n" "2,text,long_text,https://www.google.com,test@example.com,-1,1,-1.2,1.2,3,True," "02/01/2020 01:23,02/01/2020,01/02/2020 01:23,01/02/2020,01/02/2020 02:23," "01/02/2020 02:23,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021," @@ -267,7 +267,7 @@ def test_can_export_every_interesting_different_field_to_csv( '"user2@example.com,user3@example.com",\'+4412345678,test FORMULA,1,True,33.3333333333,' "1d 0:00,2020-01-01,A,test@example.com,label (https://google.com),https://google.com," '"C,D,E",3,-122.222,0:04,0:02,"linked_row_1,linked_row_2,",' - "00000000-0000-4000-8000-000000000002,2,True,I'm an AI.\r\n" + "00000000-0000-4000-8000-000000000002,2,True,I'm an AI.,Object\r\n" ) assert contents == expected diff --git a/backend/tests/baserow/contrib/database/ws/public/test_public_ws_rows_signals.py b/backend/tests/baserow/contrib/database/ws/public/test_public_ws_rows_signals.py index d7233cf5f..4a52b21ec 100644 --- a/backend/tests/baserow/contrib/database/ws/public/test_public_ws_rows_signals.py +++ b/backend/tests/baserow/contrib/database/ws/public/test_public_ws_rows_signals.py @@ -1014,6 +1014,7 @@ def test_batch_update_rows_some_not_visible_in_public_view_to_be_visible_event_s }, ], "metadata": {}, + "updated_field_ids": [hidden_field.id], }, None, None, @@ -1151,6 +1152,7 @@ def test_batch_update_rows_visible_in_public_view_to_some_not_be_visible_event_s }, ], "metadata": {}, + "updated_field_ids": [hidden_field.id], }, None, None, @@ -1454,6 +1456,7 @@ def test_given_row_visible_in_public_view_when_updated_to_still_be_visible_event } ], "metadata": {}, + "updated_field_ids": [visible_field.id, hidden_field.id], }, None, None, @@ -1580,6 +1583,7 @@ def test_batch_update_rows_visible_in_public_view_still_be_visible_event_sent( }, ], "metadata": {}, + "updated_field_ids": [visible_field.id, hidden_field.id], }, None, None, @@ -1661,6 +1665,7 @@ def test_batch_update_subset_rows_visible_in_public_view_no_filters( }, ], "metadata": {}, + "updated_field_ids": [visible_field.id], }, None, None, @@ -2021,6 +2026,7 @@ def test_given_row_visible_in_public_view_when_moved_row_updated_sent( } ], "metadata": {}, + "updated_field_ids": [], }, None, None, diff --git a/backend/tests/baserow/contrib/database/ws/test_ws_rows_signals.py b/backend/tests/baserow/contrib/database/ws/test_ws_rows_signals.py index 3825d29a3..5dafc0faf 100644 --- a/backend/tests/baserow/contrib/database/ws/test_ws_rows_signals.py +++ b/backend/tests/baserow/contrib/database/ws/test_ws_rows_signals.py @@ -315,6 +315,7 @@ def test_rows_history_updated(mock_broadcast_channel_group, data_fixture): ), ], "metadata": {}, + "updated_field_ids": [field.id], }, None, None, diff --git a/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py b/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py index d153c117a..1cac4a23e 100644 --- a/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py +++ b/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py @@ -864,6 +864,21 @@ def test_local_baserow_table_service_generate_schema_with_interesting_test_table "metadata": {}, "type": "string", }, + field_db_column_by_name["ai_choice"]: { + "title": "ai_choice", + "default": None, + "searchable": True, + "sortable": True, + "filterable": False, + "original_type": "ai", + "metadata": {}, + "properties": { + "color": {"title": "color", "type": "string"}, + "id": {"title": "id", "type": "number"}, + "value": {"title": "value", "type": "string"}, + }, + "type": "object", + }, "id": { "type": "number", "title": "Id", diff --git a/changelog/entries/unreleased/feature/3143_ai_field_choice_output_type.json b/changelog/entries/unreleased/feature/3143_ai_field_choice_output_type.json new file mode 100644 index 000000000..e90312194 --- /dev/null +++ b/changelog/entries/unreleased/feature/3143_ai_field_choice_output_type.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "message": "AI choice output type for classification purposes.", + "issue_number": 3143, + "bullet_points": [], + "created_at": "2024-11-10" +} diff --git a/enterprise/backend/src/baserow_enterprise/data_sync/baserow_table_data_sync.py b/enterprise/backend/src/baserow_enterprise/data_sync/baserow_table_data_sync.py index f03df1cb5..b78ba792e 100644 --- a/enterprise/backend/src/baserow_enterprise/data_sync/baserow_table_data_sync.py +++ b/enterprise/backend/src/baserow_enterprise/data_sync/baserow_table_data_sync.py @@ -5,6 +5,7 @@ from uuid import UUID from django.db.models import Prefetch from baserow_premium.fields.field_types import AIFieldType +from baserow_premium.fields.registries import ai_field_output_registry from baserow_premium.license.handler import LicenseHandler from rest_framework import serializers @@ -33,7 +34,6 @@ from baserow.contrib.database.fields.field_types import ( from baserow.contrib.database.fields.models import ( DateField, Field, - LongTextField, NumberField, SelectOption, TextField, @@ -51,16 +51,21 @@ from baserow_enterprise.features import DATA_SYNC from .models import LocalBaserowTableDataSync -def prepare_single_select_value(value, enabled_property): +def prepare_single_select_value(value, field, metadata): try: # The metadata contains a mapping of the select options where the key is the # old ID and the value is the new ID. For some reason the key is converted to # a string when moved into the JSON field. - return enabled_property.metadata["select_options_mapping"][str(value)] + return metadata["select_options_mapping"][str(value)] except (KeyError, TypeError): return None +def prepare_ai_value(value, field, metadata): + output_type = ai_field_output_registry.get(field.ai_output_type) + return output_type.prepare_data_sync_value(value, field, metadata) + + # List of Baserow supported field types that can be included in the data sync. supported_field_types = { TextFieldType.type: {}, @@ -78,7 +83,7 @@ supported_field_types = { LastModifiedFieldType.type: {}, UUIDFieldType.type: {}, AutonumberFieldType.type: {}, - AIFieldType.type: {}, + AIFieldType.type: {"prepare_value": prepare_ai_value}, SingleSelectFieldType.type: {"prepare_value": prepare_single_select_value}, } @@ -99,7 +104,6 @@ class BaserowFieldDataSyncProperty(DataSyncProperty): LastModifiedFieldType.type: DateField, UUIDFieldType.type: TextField, AutonumberFieldType.type: NumberField, - AIFieldType.type: LongTextField, } def __init__(self, field, immutable_properties, **kwargs): @@ -332,7 +336,9 @@ class LocalBaserowTableDataSyncType(DataSyncType): if "prepare_value" in supported_field: for row in rows_queryset: row[enabled_property.key] = supported_field["prepare_value"]( - row[enabled_property.key], enabled_property + row[enabled_property.key], + enabled_property.field, + enabled_property.metadata, ) progress.increment(by=2) # makes the total `10` diff --git a/enterprise/backend/tests/baserow_enterprise_tests/data_sync/test_local_baserow_table_data_sync_type.py b/enterprise/backend/tests/baserow_enterprise_tests/data_sync/test_local_baserow_table_data_sync_type.py index 20098ee82..4603c6c43 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/data_sync/test_local_baserow_table_data_sync_type.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/data_sync/test_local_baserow_table_data_sync_type.py @@ -389,6 +389,7 @@ def test_sync_data_sync_table_with_interesting_table_as_source(enterprise_data_f "last_modified_datetime_eu_tzone": "02/01/2021 13:00", "autonumber": "1", "ai": "", + "ai_choice": "", "uuid": "00000000-0000-4000-8000-000000000001", } assert results == { @@ -432,6 +433,7 @@ def test_sync_data_sync_table_with_interesting_table_as_source(enterprise_data_f "last_modified_datetime_eu_tzone": "02/01/2021 13:00", "autonumber": "2", "ai": "I'm an AI.", + "ai_choice": "Object", "uuid": "00000000-0000-4000-8000-000000000002", } diff --git a/premium/backend/src/baserow_premium/apps.py b/premium/backend/src/baserow_premium/apps.py index 26a16cee0..c61c2dde6 100644 --- a/premium/backend/src/baserow_premium/apps.py +++ b/premium/backend/src/baserow_premium/apps.py @@ -23,13 +23,21 @@ class BaserowPremiumConfig(AppConfig): ) from .fields.actions import GenerateFormulaWithAIActionType + from .fields.ai_field_output_types import ( + ChoiceAIFieldOutputType, + TextAIFieldOutputType, + ) from .fields.field_converters import AIFieldConverter from .fields.field_types import AIFieldType + from .fields.registries import ai_field_output_registry field_type_registry.register(AIFieldType()) field_converter_registry.register(AIFieldConverter()) + ai_field_output_registry.register(TextAIFieldOutputType()) + ai_field_output_registry.register(ChoiceAIFieldOutputType()) + from baserow.contrib.database.rows.registries import row_metadata_registry from baserow.contrib.database.views.registries import ( decorator_type_registry, diff --git a/premium/backend/src/baserow_premium/fields/ai_field_output_types.py b/premium/backend/src/baserow_premium/fields/ai_field_output_types.py new file mode 100644 index 000000000..2d8e32018 --- /dev/null +++ b/premium/backend/src/baserow_premium/fields/ai_field_output_types.py @@ -0,0 +1,97 @@ +import enum +import json +from difflib import get_close_matches +from typing import Any + +from langchain.output_parsers.enum import EnumOutputParser +from langchain_core.exceptions import OutputParserException +from langchain_core.prompts import PromptTemplate + +from baserow.contrib.database.fields.field_types import ( + LongTextFieldType, + SingleSelectFieldType, +) + +from .registries import AIFieldOutputType + + +class TextAIFieldOutputType(AIFieldOutputType): + type = "text" + baserow_field_type = LongTextFieldType + + +class StrictEnumOutputParser(EnumOutputParser): + def get_format_instructions(self) -> str: + json_array = json.dumps(self._valid_values) + return f"""Categorize the result following these requirements: + +- Select only one option from the JSON array below. +- Don't use quotes or commas or partial values, just the option name. +- Choose the option that most closely matches the row values. + +```json +{json_array} +```""" # nosec this falsely marks as hardcoded sql expression, but it's not related + # to SQL at all. + + def parse(self, response: str) -> Any: + response = response.strip() + # Sometimes the LLM responds with a quotes value or with part of the value if + # it contains a comma. Finding the close matches helps with selecting the + # right value. + closest_matches = get_close_matches( + response, self._valid_values, n=1, cutoff=0.0 + ) + return super().parse(closest_matches[0]) + + +class ChoiceAIFieldOutputType(AIFieldOutputType): + type = "choice" + baserow_field_type = SingleSelectFieldType + + def get_output_parser(self, ai_field): + choices = enum.Enum( + "Choices", + { + f"OPTION_{option.id}": option.value + for option in ai_field.select_options.all() + }, + ) + return StrictEnumOutputParser(enum=choices) + + def format_prompt(self, prompt, ai_field): + output_parser = self.get_output_parser(ai_field) + format_instructions = output_parser.get_format_instructions() + prompt = PromptTemplate( + template=prompt + "Given this user query: \n\n{format_instructions}", + input_variables=[], + partial_variables={"format_instructions": format_instructions}, + ) + message = prompt.format() + return message + + def parse_output(self, output, ai_field): + if not output: + return None + + output_parser = self.get_output_parser(ai_field) + try: + parsed_output = output_parser.parse(output) + except OutputParserException: + return None + select_option_id = int(parsed_output.name.split("_")[1]) + try: + return next( + o for o in ai_field.select_options.all() if o.id == select_option_id + ) + except StopIteration: + return None + + def prepare_data_sync_value(self, value, field, metadata): + try: + # The metadata contains a mapping of the select options where the key is the + # old ID and the value is the new ID. For some reason the key is converted + # to a string when moved into the JSON field. + return int(metadata["select_options_mapping"][str(value)]) + except (KeyError, TypeError): + return None diff --git a/premium/backend/src/baserow_premium/fields/field_converters.py b/premium/backend/src/baserow_premium/fields/field_converters.py index 7286ada9c..63ad645f7 100644 --- a/premium/backend/src/baserow_premium/fields/field_converters.py +++ b/premium/backend/src/baserow_premium/fields/field_converters.py @@ -1,5 +1,4 @@ from baserow.contrib.database.fields.field_converters import RecreateFieldConverter -from baserow.contrib.database.fields.models import LongTextField, TextField from .models import AIField @@ -10,5 +9,5 @@ class AIFieldConverter(RecreateFieldConverter): def is_applicable(self, from_model, from_field, to_field): from_ai = isinstance(from_field, AIField) to_ai = isinstance(to_field, AIField) - to_text_fields = isinstance(to_field, (TextField, LongTextField)) - return from_ai and not (to_text_fields or to_ai) or not from_ai and to_ai + # If any field converts to the AI field, then we want to recreate the field + return not from_ai and to_ai diff --git a/premium/backend/src/baserow_premium/fields/field_types.py b/premium/backend/src/baserow_premium/fields/field_types.py index 067002afa..23a3a060e 100644 --- a/premium/backend/src/baserow_premium/fields/field_types.py +++ b/premium/backend/src/baserow_premium/fields/field_types.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING from django.contrib.auth import get_user_model -from django.db import IntegrityError, models -from django.db.models import Value +from django.db import IntegrityError +from django.utils.functional import lazy from baserow_premium.api.fields.exceptions import ( ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD, @@ -15,15 +15,13 @@ from baserow.api.generative_ai.errors import ( ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE, ) from baserow.contrib.database.api.fields.errors import ERROR_FIELD_DOES_NOT_EXIST -from baserow.contrib.database.fields.field_filters import ( - contains_filter, - contains_word_filter, +from baserow.contrib.database.fields.field_types import ( + CollationSortMixin, + SelectOptionBaseFieldType, ) -from baserow.contrib.database.fields.field_types import CollationSortMixin, TextField from baserow.contrib.database.fields.models import Field -from baserow.contrib.database.fields.registries import FieldType -from baserow.contrib.database.formula import BaserowFormulaTextType, BaserowFormulaType -from baserow.core.db import collate_expression +from baserow.contrib.database.fields.registries import field_type_registry +from baserow.contrib.database.formula import BaserowFormulaType from baserow.core.formula.serializers import FormulaSerializerField from baserow.core.generative_ai.exceptions import ( GenerativeAITypeDoesNotExist, @@ -35,6 +33,7 @@ from baserow.core.generative_ai.registries import ( ) from .models import AIField +from .registries import ai_field_output_registry from .visitors import replace_field_id_references User = get_user_model() @@ -43,7 +42,7 @@ if TYPE_CHECKING: from baserow.contrib.database.table.models import GeneratedTableModel -class AIFieldType(CollationSortMixin, FieldType): +class AIFieldType(CollationSortMixin, SelectOptionBaseFieldType): """ The AI field can automatically query a generative AI model based on the provided prompt. It's possible to reference other fields to generate a unique output. @@ -53,21 +52,29 @@ class AIFieldType(CollationSortMixin, FieldType): model_class = AIField can_be_in_form_view = False keep_data_on_duplication = True - allowed_fields = [ + allowed_fields = SelectOptionBaseFieldType.allowed_fields + [ "ai_generative_ai_type", "ai_generative_ai_model", + "ai_output_type", "ai_temperature", "ai_prompt", "ai_file_field_id", ] - serializer_field_names = [ + serializer_field_names = SelectOptionBaseFieldType.allowed_fields + [ "ai_generative_ai_type", "ai_generative_ai_model", + "ai_output_type", "ai_temperature", "ai_prompt", "ai_file_field_id", ] serializer_field_overrides = { + "ai_output_type": serializers.ChoiceField( + required=False, + choices=lazy(ai_field_output_registry.get_types, list)(), + help_text="The desired output type of the field. It will try to force the " + "response of the prompt to match it.", + ), "ai_temperature": serializers.FloatField( required=False, allow_null=True, @@ -89,6 +96,7 @@ class AIFieldType(CollationSortMixin, FieldType): allow_null=True, default=None, ), + **SelectOptionBaseFieldType.serializer_field_overrides, } api_exceptions_map = { GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST, @@ -96,52 +104,173 @@ class AIFieldType(CollationSortMixin, FieldType): GenerativeAITypeDoesNotSupportFileField: ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD, IntegrityError: ERROR_FIELD_DOES_NOT_EXIST, } - can_get_unique_values = False + can_get_unique_values = True + can_have_select_options = True + + def get_baserow_field_type(self, instance): + output_type = ai_field_output_registry.get(instance.ai_output_type) + baserow_field_type = field_type_registry.get_by_type( + output_type.baserow_field_type + ) + return baserow_field_type def get_serializer_field(self, instance, **kwargs): - required = kwargs.get("required", False) - return serializers.CharField( - **{ - "required": required, - "allow_null": not required, - "allow_blank": not required, - **kwargs, - } - ) + kwargs["read_only"] = True + baserow_field_type = self.get_baserow_field_type(instance) + return baserow_field_type.get_serializer_field(instance, **kwargs) + + def get_response_serializer_field(self, instance, **kwargs): + baserow_field_type = self.get_baserow_field_type(instance) + return baserow_field_type.get_response_serializer_field(instance, **kwargs) def get_model_field(self, instance, **kwargs): - return models.TextField(null=True, **kwargs) + baserow_field_type = self.get_baserow_field_type(instance) + return baserow_field_type.get_model_field(instance, **kwargs) def get_serializer_help_text(self, instance): return ( - "Holds a text value that is generated by a generative AI model using a " + "Holds a value that is generated by a generative AI model using a " "dynamic prompt." ) def random_value(self, instance, fake, cache): - return fake.name() + baserow_field_type = self.get_baserow_field_type(instance) + return baserow_field_type.random_value(instance, fake, cache) def to_baserow_formula_type(self, field) -> BaserowFormulaType: - return BaserowFormulaTextType(nullable=True) - - def from_baserow_formula_type( - self, formula_type: BaserowFormulaTextType - ) -> TextField: - return TextField() + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.to_baserow_formula_type(field) def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> any: - value = getattr(row, field.db_column) - return collate_expression(Value(value)) + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.get_value_for_filter(row, field) - def contains_query(self, *args): - return contains_filter(*args) + def get_alter_column_prepare_old_value(self, connection, from_field, to_field): + baserow_field_type = self.get_baserow_field_type(from_field) + return baserow_field_type.get_alter_column_prepare_old_value( + connection, from_field, to_field + ) - def contains_word_query(self, *args): - return contains_word_filter(*args) + def get_alter_column_prepare_new_value(self, connection, from_field, to_field): + baserow_field_type = self.get_baserow_field_type(to_field) + return baserow_field_type.get_alter_column_prepare_new_value( + connection, from_field, to_field + ) + + def contains_query(self, field_name, value, model_field, field): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.contains_query(field_name, value, model_field, field) + + def contains_word_query(self, field_name, value, model_field, field): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.contains_word_query( + field_name, value, model_field, field + ) + + def check_can_order_by(self, field: Field) -> bool: + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.check_can_order_by(field) + + def check_can_group_by(self, field: Field) -> bool: + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.check_can_group_by(field) + + def get_search_expression(self, field: Field, queryset): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.get_search_expression(field, queryset) + + def is_searchable(self, field): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.is_searchable(field) + + def enhance_queryset(self, queryset, field, name, **kwargs): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.enhance_queryset(queryset, field, name) + + def get_order(self, field, field_name, order_direction): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.get_order(field, field_name, order_direction) + + def serialize_to_input_value(self, field, value): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.serialize_to_input_value(field, value) + + def valid_for_bulk_update(self, field): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.valid_for_bulk_update(field) + + def prepare_value_for_db(self, instance, value): + baserow_field_type = self.get_baserow_field_type(instance) + return baserow_field_type.prepare_value_for_db(instance, value) + + def prepare_value_for_db_in_bulk( + self, instance, values_by_row, continue_on_error=False + ): + baserow_field_type = self.get_baserow_field_type(instance) + return baserow_field_type.prepare_value_for_db_in_bulk( + instance, values_by_row, continue_on_error + ) + + def get_group_by_serializer_field(self, field, **kwargs): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.get_group_by_serializer_field(field, **kwargs) + + def get_group_by_field_unique_value(self, field, field_name, value): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.get_group_by_field_unique_value( + field, field_name, value + ) + + def get_group_by_field_filters_and_annotations( + self, field, field_name, base_queryset, value + ): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.get_group_by_field_filters_and_annotations( + field, field_name, base_queryset, value + ) + + def get_export_serialized_value( + self, + row, + field_name, + cache, + files_zip=None, + storage=None, + ): + field_object = row.get_field_object(field_name) + baserow_field_type = self.get_baserow_field_type(field_object["field"]) + return baserow_field_type.get_export_serialized_value( + row, field_name, cache, files_zip, storage + ) + + def set_import_serialized_value( + self, + row, + field_name, + value, + id_mapping, + cache, + files_zip=None, + storage=None, + ): + field_object = row.get_field_object(field_name) + baserow_field_type = self.get_baserow_field_type(field_object["field"]) + return baserow_field_type.set_import_serialized_value( + row, field_name, value, id_mapping, cache, files_zip, storage + ) + + def get_export_value(self, value, field_object, rich_value=False): + baserow_field_type = self.get_baserow_field_type(field_object["field"]) + return baserow_field_type.get_export_value(value, field_object, rich_value) + + def get_human_readable_value(self, value, field_object): + baserow_field_type = self.get_baserow_field_type(field_object["field"]) + return baserow_field_type.get_human_readable_value(value, field_object) def _validate_field_kwargs( - self, ai_type, model_type, ai_file_field_id, workspace=None + self, ai_output_type, ai_type, model_type, ai_file_field_id, workspace=None ): + ai_field_output_registry.get(ai_output_type) ai_type = generative_ai_model_type_registry.get(ai_type) models = ai_type.get_enabled_models(workspace=workspace) if model_type not in models: @@ -154,12 +283,19 @@ class AIFieldType(CollationSortMixin, FieldType): def before_create( self, table, primary, allowed_field_values, order, user, field_kwargs ): + ai_output_type = field_kwargs.get( + "ai_output_type", AIField._meta.get_field("ai_output_type").default + ) ai_type = field_kwargs.get("ai_generative_ai_type", None) model_type = field_kwargs.get("ai_generative_ai_model", None) ai_file_field_id = field_kwargs.get("ai_file_field_id", None) workspace = table.database.workspace self._validate_field_kwargs( - ai_type, model_type, ai_file_field_id, workspace=workspace + ai_output_type, ai_type, model_type, ai_file_field_id, workspace=workspace + ) + + return super().before_create( + table, primary, allowed_field_values, order, user, field_kwargs ) def before_update(self, from_field, to_field_values, user, field_kwargs): @@ -167,6 +303,11 @@ class AIFieldType(CollationSortMixin, FieldType): if isinstance(from_field, AIField): update_field = from_field + ai_output_type = ( + field_kwargs.get("ai_output_type", None) + or getattr(update_field, "ai_output_type", None) + or AIField._meta.get_field("ai_output_type").default + ) ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr( update_field, "ai_generative_ai_type", None ) @@ -179,9 +320,11 @@ class AIFieldType(CollationSortMixin, FieldType): ai_file_field_id = getattr(update_field, "ai_file_field_id", None) workspace = from_field.table.database.workspace self._validate_field_kwargs( - ai_type, model_type, ai_file_field_id, workspace=workspace + ai_output_type, ai_type, model_type, ai_file_field_id, workspace=workspace ) + return super().before_update(from_field, to_field_values, user, field_kwargs) + def after_import_serialized( self, field: AIField, @@ -209,3 +352,17 @@ class AIFieldType(CollationSortMixin, FieldType): if save: field.save() + + def should_backup_field_data_for_same_type_update( + self, old_field, new_field_attrs + ) -> bool: + backup = super().should_backup_field_data_for_same_type_update( + old_field, new_field_attrs + ) + # Backup the field if the output type changes because + ai_output_changed = ( + "ai_output_type" in new_field_attrs + and new_field_attrs["ai_output_type"] + and new_field_attrs["ai_output_type"] != old_field.ai_output_type + ) + return backup or ai_output_changed diff --git a/premium/backend/src/baserow_premium/fields/models.py b/premium/backend/src/baserow_premium/fields/models.py index 27d69102c..76f7a64f0 100644 --- a/premium/backend/src/baserow_premium/fields/models.py +++ b/premium/backend/src/baserow_premium/fields/models.py @@ -3,12 +3,34 @@ from django.db import models from baserow.contrib.database.fields.models import Field from baserow.core.formula.field import FormulaField as ModelFormulaField +from .ai_field_output_types import TextAIFieldOutputType +from .registries import ai_field_output_registry + class AIField(Field): ai_generative_ai_type = models.CharField(max_length=32, null=True) ai_generative_ai_model = models.CharField(max_length=32, null=True) + ai_output_type = models.CharField( + max_length=32, + db_default=TextAIFieldOutputType.type, + default=TextAIFieldOutputType.type, + ) ai_temperature = models.FloatField(null=True) ai_prompt = ModelFormulaField(default="") ai_file_field = models.ForeignKey( Field, null=True, on_delete=models.SET_NULL, related_name="ai_field" ) + + def __getattr__(self, name): + """ + When a property is called on the field object, it tries to return the default + value of the field object related to the `ai_output_type` `model_class`. This + will make it more compatible with the check functions like `check_can_group_by`. + """ + + try: + ai_output_type = ai_field_output_registry.get(self.ai_output_type) + output_field = ai_output_type.baserow_field_type.model_class + return output_field._meta.get_field(name).default + except Exception: + super().__getattr__(name) diff --git a/premium/backend/src/baserow_premium/fields/registries.py b/premium/backend/src/baserow_premium/fields/registries.py new file mode 100644 index 000000000..940c827e7 --- /dev/null +++ b/premium/backend/src/baserow_premium/fields/registries.py @@ -0,0 +1,70 @@ +import abc +import typing +from typing import Any + +from baserow.contrib.database.fields.models import Field +from baserow.core.registry import Instance, Registry + +if typing.TYPE_CHECKING: + from baserow_premium.fields.models import AIField + + +class AIFieldOutputType(abc.ABC, Instance): + @property + @abc.abstractmethod + def baserow_field_type(self) -> str: + """ + The Baserow field type that corresponds to this AI output type and should be + used to do various Baserow operations like filtering, sorting, etc. + """ + + def format_prompt(self, prompt: str, ai_field: "AIField"): + """ + Hook that can be used to change and format the provided prompt for this output + type. It accepts the original already resolved prompt and should return the + updated one. + + It can be used to include the format instructions of an output parser, for + example. + + :param prompt: The resolved prompt provided by the user. This already contains + the resolved variables. + :param ai_field: The AI field related to the output type. + :return: Should return the formatted prompt. This can include additional + information that can change the outcome of the prompt. + """ + + return prompt + + def parse_output(self, output: str, ai_field: "AIField"): + """ + Hook that can be used to parse the output of the generative AI prompt. If an + output parser formatting instructions are added in `format_prompt`, then this + hook can be used to parse it. + + :param output: The text output of the generative AI. + :param ai_field: The AI field related to the output type. + :return: Should return the parsed output. + """ + + return output + + def prepare_data_sync_value(self, value: Any, field: Field, metadata: dict) -> Any: + """ + Hook that's called when preparing the value in the local Baserow data sync. + It's for example used to map the value of the single select option. + + :param value: The original value. + :param field: The field that's synced. + :param metadata: The metadata related to the datasync property. + :return: The updated value. + """ + + return value + + +class AIFieldOutputRegistry(Registry): + name = "ai_field_output" + + +ai_field_output_registry: AIFieldOutputRegistry = AIFieldOutputRegistry() diff --git a/premium/backend/src/baserow_premium/fields/tasks.py b/premium/backend/src/baserow_premium/fields/tasks.py index 7de0aa073..e3006e786 100644 --- a/premium/backend/src/baserow_premium/fields/tasks.py +++ b/premium/backend/src/baserow_premium/fields/tasks.py @@ -20,6 +20,7 @@ from baserow.core.handler import CoreHandler from baserow.core.user.handler import User from .models import AIField +from .registries import ai_field_output_registry @app.task(bind=True, queue="export") @@ -28,9 +29,9 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list ai_field = FieldHandler().get_field( field_id, - base_queryset=AIField.objects.all().select_related( - "table__database__workspace" - ), + base_queryset=AIField.objects.all() + .select_related("table__database__workspace") + .prefetch_related("select_options"), ) table = ai_field.table workspace = table.database.workspace @@ -71,6 +72,8 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list ) raise exc + ai_output_type = ai_field_output_registry.get(ai_field.ai_output_type) + for i, row in enumerate(rows): context = HumanReadableRowContext(row, exclude_field_ids=[ai_field.id]) message = str( @@ -79,6 +82,11 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list ) ) + # The AI output type should be able to format the prompt because it can add + # additional instructions to it. The choice output type for example adds + # additional prompt trying to force the out, for example. + message = ai_output_type.format_prompt(message, ai_field) + try: if ai_field.ai_file_field_id is not None and isinstance( generative_ai_model_type, GenerativeAIWithFilesModelType @@ -105,6 +113,12 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list workspace=workspace, temperature=ai_field.ai_temperature, ) + + # Because the AI output type can change the prompt to try to force the + # output a certain way, then it should give the opportunity to parse the + # output when it's given. With the choice output type, it will try to match + # it to a `SelectOption`, for example. + value = ai_output_type.parse_output(value, ai_field) except Exception as exc: # If the prompt fails once, we should not continue with the other rows. rows_ai_values_generation_error.send( diff --git a/premium/backend/src/baserow_premium/migrations/0023_aifield_ai_output_type.py b/premium/backend/src/baserow_premium/migrations/0023_aifield_ai_output_type.py new file mode 100644 index 000000000..89ed9a56e --- /dev/null +++ b/premium/backend/src/baserow_premium/migrations/0023_aifield_ai_output_type.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.9 on 2024-10-28 13:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("baserow_premium", "0022_aifield_ai_temperature"), + ] + + operations = [ + migrations.AddField( + model_name="aifield", + name="ai_output_type", + field=models.CharField(db_default="text", default="text", max_length=32), + ), + ] diff --git a/premium/backend/tests/baserow_premium_tests/export/test_premium_export_types.py b/premium/backend/tests/baserow_premium_tests/export/test_premium_export_types.py index 2271dd97e..d7694f51d 100644 --- a/premium/backend/tests/baserow_premium_tests/export/test_premium_export_types.py +++ b/premium/backend/tests/baserow_premium_tests/export/test_premium_export_types.py @@ -111,7 +111,8 @@ def test_can_export_every_interesting_different_field_to_json( "uuid": "00000000-0000-4000-8000-000000000001", "autonumber": 1, "password": "", - "ai": "" + "ai": "", + "ai_choice": "" }, { "id": 2, @@ -230,7 +231,8 @@ def test_can_export_every_interesting_different_field_to_json( "uuid": "00000000-0000-4000-8000-000000000002", "autonumber": 2, "password": true, - "ai": "I'm an AI." + "ai": "I'm an AI.", + "ai_choice": "Object" } ] """ @@ -393,6 +395,7 @@ def test_can_export_every_interesting_different_field_to_xml( <autonumber>1</autonumber> <password/> <ai/> + <ai-choice/> </row> <row> <id>2</id> @@ -512,6 +515,7 @@ def test_can_export_every_interesting_different_field_to_xml( <autonumber>2</autonumber> <password>true</password> <ai>I'm an AI.</ai> + <ai-choice>Object</ai-choice> </row> </rows> """ diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_models.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_models.py new file mode 100644 index 000000000..a2fec82f4 --- /dev/null +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_models.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_dynamic_get_attr(premium_data_fixture): + field = premium_data_fixture.create_ai_field(ai_output_type="text") + assert field.long_text_enable_rich_text is False + + with pytest.raises(AttributeError): + field.non_existing_property diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py new file mode 100644 index 000000000..bc07fbd89 --- /dev/null +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py @@ -0,0 +1,109 @@ +import enum + +import pytest +from baserow_premium.fields.ai_field_output_types import StrictEnumOutputParser +from baserow_premium.fields.tasks import generate_ai_values_for_rows +from langchain_core.prompts import PromptTemplate + +from baserow.core.generative_ai.registries import ( + GenerativeAIModelType, + generative_ai_model_type_registry, +) + + +def test_strict_enum_output_parser(): + choices = enum.Enum( + "Choices", + { + "OPTION_1": "Object", + "OPTION_2": "Animal", + "OPTION_3": "Human", + "OPTION_4": "A,B,C", + }, + ) + output_parser = StrictEnumOutputParser(enum=choices) + format_instructions = output_parser.get_format_instructions() + prompt = "What is a motorcycle?" + prompt = PromptTemplate( + template=prompt + "\n\n{format_instructions}", + input_variables=[], + partial_variables={"format_instructions": format_instructions}, + ) + message = prompt.format() + + assert '["Object", "Animal", "Human", "A,B,C"]' in message + + assert output_parser.parse("Object") == choices.OPTION_1 + assert output_parser.parse("Animal") == choices.OPTION_2 + assert output_parser.parse("Human") == choices.OPTION_3 + assert output_parser.parse("A,B,C") == choices.OPTION_4 + + assert output_parser.parse(" Object ") == choices.OPTION_1 + assert output_parser.parse("'Object'") == choices.OPTION_1 + assert output_parser.parse("'A'") == choices.OPTION_4 + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_choice_output_type(premium_data_fixture, api_client): + class TestAIChoiceOutputTypeGenerativeAIModelType(GenerativeAIModelType): + type = "test_ai_choice_ouput_type" + i = 0 + + def is_enabled(self, workspace=None): + return True + + def get_enabled_models(self, workspace=None): + return ["test_1"] + + def prompt(self, model, prompt, workspace=None, temperature=None): + self.i += 1 + if self.i == 1: + # Existing option should be matches based on the string. + return "Object" + else: + return "Else" + + def get_settings_serializer(self): + return None + + generative_ai_model_type_registry.register( + TestAIChoiceOutputTypeGenerativeAIModelType() + ) + + premium_data_fixture.register_fake_generate_ai_type() + user = premium_data_fixture.create_user() + database = premium_data_fixture.create_database_application( + user=user, name="Placeholder" + ) + table = premium_data_fixture.create_database_table( + name="Example", database=database + ) + field = premium_data_fixture.create_ai_field( + table=table, + order=0, + name="ai", + ai_output_type="choice", + ai_generative_ai_type="test_ai_choice_ouput_type", + ai_generative_ai_model="test_1", + ai_prompt="'Option'", + ) + option_1 = premium_data_fixture.create_select_option( + field=field, value="Object", color="red" + ) + option_2 = premium_data_fixture.create_select_option( + field=field, value="Else", color="red" + ) + premium_data_fixture.create_select_option(field=field, value="Animal", color="blue") + + model = table.get_model() + row_1 = model.objects.create() + row_2 = model.objects.create() + + generate_ai_values_for_rows(user.id, field.id, [row_1.id, row_2.id]) + + row_1.refresh_from_db() + row_2.refresh_from_db() + + assert getattr(row_1, f"field_{field.id}").id == option_1.id + assert getattr(row_2, f"field_{field.id}").id == option_2.id diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py index e04951ab6..ac08c2742 100644 --- a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py @@ -1,10 +1,12 @@ from django.shortcuts import reverse import pytest +from baserow_premium.fields.field_types import AIFieldType from baserow_premium.fields.models import AIField from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.registries import field_type_registry from baserow.contrib.database.table.handler import TableHandler from baserow.core.db import specific_iterator @@ -28,6 +30,7 @@ def test_create_ai_field_type(premium_data_fixture): ai_prompt="'Who are you?'", ) + assert ai_field.ai_output_type == "text" # default value assert ai_field.ai_generative_ai_type == "test_generative_ai" assert ai_field.ai_generative_ai_model == "test_1" assert ai_field.ai_prompt == "'Who are you?'" @@ -51,6 +54,7 @@ def test_update_ai_field_type(premium_data_fixture): ai_prompt="'Who are you?'", ) + assert ai_field.ai_output_type == "text" # default value assert ai_field.ai_generative_ai_type == "test_generative_ai" assert ai_field.ai_generative_ai_model == "test_1" assert ai_field.ai_prompt == "'Who are you?'" @@ -98,6 +102,68 @@ def test_create_ai_field_type_via_api(premium_data_fixture, api_client): ) response_json = response.json() assert response.status_code == HTTP_200_OK + assert response_json["ai_output_type"] == "text" + assert response_json["ai_generative_ai_type"] == "test_generative_ai" + assert response_json["ai_generative_ai_model"] == "test_1" + assert response_json["ai_prompt"] == "'Who are you?'" + assert response_json["ai_temperature"] is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_create_ai_field_type_via_api_with_non_existing_ai_output_type( + premium_data_fixture, api_client +): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + premium_data_fixture.create_text_field(table=table, order=1, name="name") + + response = api_client.post( + reverse("api:database:fields:list", kwargs={"table_id": table.id}), + { + "name": "Test 1", + "type": "ai", + "ai_output_type": "DOES_NOT_EXIST", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + "ai_prompt": "'Who are you?'", + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + assert response.status_code == HTTP_400_BAD_REQUEST + assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION" + assert response_json["detail"]["ai_output_type"][0]["code"] == "invalid_choice" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_create_ai_field_type_via_api_with_ai_output_type( + premium_data_fixture, api_client +): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + premium_data_fixture.create_text_field(table=table, order=1, name="name") + + response = api_client.post( + reverse("api:database:fields:list", kwargs={"table_id": table.id}), + { + "name": "Test 1", + "type": "ai", + "ai_output_type": "text", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + "ai_prompt": "'Who are you?'", + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + assert response.status_code == HTTP_200_OK + assert response_json["ai_output_type"] == "text" assert response_json["ai_generative_ai_type"] == "test_generative_ai" assert response_json["ai_generative_ai_model"] == "test_1" assert response_json["ai_prompt"] == "'Who are you?'" @@ -191,9 +257,6 @@ def test_update_ai_field_temperature_none_via_api(premium_data_fixture, api_clie field = premium_data_fixture.create_ai_field( table=table, order=1, name="name", ai_temperature=0.7 ) - file_field = premium_data_fixture.create_file_field( - table=table, order=2, name="file" - ) response = api_client.patch( reverse("api:database:fields:item", kwargs={"field_id": field.id}), @@ -209,6 +272,121 @@ def test_update_ai_field_temperature_none_via_api(premium_data_fixture, api_clie assert response.json()["ai_temperature"] is None +@pytest.mark.django_db +@pytest.mark.field_ai +def test_update_ai_field_via_api_invalid_output_type(premium_data_fixture, api_client): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + field = premium_data_fixture.create_ai_field( + table=table, order=1, name="name", ai_temperature=0.7 + ) + + response = api_client.patch( + reverse("api:database:fields:item", kwargs={"field_id": field.id}), + { + "ai_output_type": "INVALID_CHOICE", + "ai_generative_ai_type": "test_generative_ai_with_files", + "ai_generative_ai_model": "test_1", + "ai_temperature": None, + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + assert response.status_code == HTTP_400_BAD_REQUEST + assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION" + assert response_json["detail"]["ai_output_type"][0]["code"] == "invalid_choice" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_update_ai_field_via_api_valid_output_type(premium_data_fixture, api_client): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + field = premium_data_fixture.create_ai_field( + table=table, order=1, name="name", ai_temperature=0.7 + ) + + response = api_client.patch( + reverse("api:database:fields:item", kwargs={"field_id": field.id}), + { + "ai_output_type": "text", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + "ai_temperature": None, + "ai_prompt": "'Who are you?'", + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + assert response.status_code == HTTP_200_OK + assert response_json["ai_output_type"] == "text" + assert response_json["ai_generative_ai_type"] == "test_generative_ai" + assert response_json["ai_generative_ai_model"] == "test_1" + assert response_json["ai_prompt"] == "'Who are you?'" + assert response_json["ai_temperature"] is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_update_to_ai_field_with_all_parameters(premium_data_fixture, api_client): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + field = premium_data_fixture.create_text_field(table=table, order=1, name="name") + + response = api_client.patch( + reverse("api:database:fields:item", kwargs={"field_id": field.id}), + { + "type": "ai", + "ai_output_type": "text", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + "ai_temperature": None, + "ai_prompt": "'Who are you?'", + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + assert response.status_code == HTTP_200_OK + assert response_json["ai_output_type"] == "text" + assert response_json["ai_generative_ai_type"] == "test_generative_ai" + assert response_json["ai_generative_ai_model"] == "test_1" + assert response_json["ai_prompt"] == "'Who are you?'" + assert response_json["ai_temperature"] is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_update_to_ai_field_without_parameters(premium_data_fixture, api_client): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + field = premium_data_fixture.create_text_field(table=table, order=1, name="name") + + response = api_client.patch( + reverse("api:database:fields:item", kwargs={"field_id": field.id}), + { + "type": "ai", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + assert response.status_code == HTTP_200_OK + assert response_json["ai_output_type"] == "text" + assert response_json["ai_generative_ai_type"] == "test_generative_ai" + assert response_json["ai_generative_ai_model"] == "test_1" + assert response_json["ai_prompt"] == "" + assert response_json["ai_temperature"] is None + + @pytest.mark.django_db @pytest.mark.field_ai def test_create_ai_field_type_via_api_invalid_formula(premium_data_fixture, api_client): @@ -606,3 +784,266 @@ def test_duplicate_table_with_ai_field_broken_references(premium_data_fixture): duplicated_ai_field = duplicated_fields[2] assert duplicated_ai_field.ai_prompt == f"concat('test:',get('fields.field_0'))" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_can_set_select_options_to_choice_ai_output_type( + premium_data_fixture, api_client +): + user, token = premium_data_fixture.create_user_and_token() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + premium_data_fixture.create_text_field(table=table, order=1, name="name") + + response = api_client.post( + reverse("api:database:fields:list", kwargs={"table_id": table.id}), + { + "name": "Test 1", + "type": "ai", + "ai_output_type": "choice", + "ai_generative_ai_type": "test_generative_ai", + "ai_generative_ai_model": "test_1", + "ai_prompt": "'Who are you?'", + "select_options": [ + {"value": "Small", "color": "red"}, + {"value": "Medium", "color": "blue"}, + {"value": "Large", "color": "green"}, + ], + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + response_json = response.json() + + assert response.status_code == HTTP_200_OK + assert response_json["ai_output_type"] == "choice" + assert response_json["ai_generative_ai_type"] == "test_generative_ai" + assert response_json["ai_generative_ai_model"] == "test_1" + assert response_json["ai_prompt"] == "'Who are you?'" + assert len(response_json["select_options"]) == 3 + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_should_backup(premium_data_fixture, api_client): + ai_field_type = field_type_registry.get(AIFieldType.type) + ai_field = premium_data_fixture.create_ai_field() + file_field = premium_data_fixture.create_file_field(table=ai_field.table) + + assert ( + ai_field_type.should_backup_field_data_for_same_type_update( + ai_field, + { + "ai_generative_ai_type": "test_generative_ai_2", + "ai_generative_ai_model": "test_model_2", + "ai_prompt": "'New AI prompt'", + "ai_output_type": "text", # same as before + "ai_temperature": 1, + "ai_file_field": file_field, + }, + ) + is False + ) + + assert ( + ai_field_type.should_backup_field_data_for_same_type_update( + ai_field, + { + "ai_generative_ai_type": "test_generative_ai_2", + "ai_generative_ai_model": "test_model_2", + "ai_prompt": "'New AI prompt'", + "ai_output_type": "choice", # new one + "ai_temperature": 1, + "ai_file_field": file_field, + }, + ) + is True + ) # Expect to make a backup when output type changes. + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_can_convert_from_text_output_type_to_choice_output_type( + premium_data_fixture, api_client +): + premium_data_fixture.register_fake_generate_ai_type() + user = premium_data_fixture.create_user() + database = premium_data_fixture.create_database_application( + user=user, name="Placeholder" + ) + table = premium_data_fixture.create_database_table( + name="Example", database=database + ) + field = premium_data_fixture.create_ai_field( + table=table, order=0, name="ai", ai_output_type="text" + ) + + model = table.get_model() + model.objects.create(**{f"field_{field.id}": "Option 1"}) + model.objects.create(**{f"field_{field.id}": "Something else"}) + + field = FieldHandler().update_field( + user=user, + field=field, + ai_output_type="choice", + select_options=[ + {"value": "Option 1", "color": "red"}, + ], + ) + + table.refresh_from_db() + model = table.get_model() + rows = list(model.objects.all()) + + # Converting text ai field to choice field should try to convert the text values to + # the new choices. + assert getattr(rows[0], f"field_{field.id}").value == "Option 1" + assert getattr(rows[1], f"field_{field.id}") is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_can_convert_from_choice_output_type_to_text_output_type( + premium_data_fixture, api_client +): + premium_data_fixture.register_fake_generate_ai_type() + user = premium_data_fixture.create_user() + database = premium_data_fixture.create_database_application( + user=user, name="Placeholder" + ) + table = premium_data_fixture.create_database_table( + name="Example", database=database + ) + field = premium_data_fixture.create_ai_field( + table=table, order=0, name="ai", ai_output_type="choice" + ) + select_option = premium_data_fixture.create_select_option( + field=field, value="Option 1", color="blue", order=0 + ) + + model = table.get_model() + model.objects.create(**{f"field_{field.id}_id": select_option.id}) + model.objects.create(**{f"field_{field.id}": None}) + + field = FieldHandler().update_field( + user=user, + field=field, + ai_output_type="text", + ) + + table.refresh_from_db() + model = table.get_model() + rows = list(model.objects.all()) + + # Converting choice ai field to text ai field should try to convert the choices to + # text values. + assert getattr(rows[0], f"field_{field.id}") == "Option 1" + assert getattr(rows[1], f"field_{field.id}") is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_can_convert_from_text_field_to_text_output_type( + premium_data_fixture, api_client +): + premium_data_fixture.register_fake_generate_ai_type() + user = premium_data_fixture.create_user() + database = premium_data_fixture.create_database_application( + user=user, name="Placeholder" + ) + table = premium_data_fixture.create_database_table( + name="Example", database=database + ) + field = premium_data_fixture.create_text_field(table=table, order=0, name="text") + + model = table.get_model() + model.objects.create(**{f"field_{field.id}": "Test"}) + + field = FieldHandler().update_field( + user=user, + field=field, + new_type_name="ai", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_prompt="'test'", + ai_output_type="text", + ) + + table.refresh_from_db() + model = table.get_model() + rows = list(model.objects.all()) + + # Expect the value to be reset because we don't want to keep the existing cell + # value when converting from any other field. + assert getattr(rows[0], f"field_{field.id}") is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_can_convert_from_text_field_to_choice_output_type( + premium_data_fixture, api_client +): + premium_data_fixture.register_fake_generate_ai_type() + user = premium_data_fixture.create_user() + database = premium_data_fixture.create_database_application( + user=user, name="Placeholder" + ) + table = premium_data_fixture.create_database_table( + name="Example", database=database + ) + field = premium_data_fixture.create_text_field(table=table, order=0, name="text") + + model = table.get_model() + model.objects.create(**{f"field_{field.id}": "Test"}) + + field = FieldHandler().update_field( + user=user, + field=field, + new_type_name="ai", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_prompt="'test'", + ai_output_type="choice", + ) + + table.refresh_from_db() + model = table.get_model() + rows = list(model.objects.all()) + + # Expect the value to be reset because we don't want to keep the existing cell + # value when converting from any other field. + assert getattr(rows[0], f"field_{field.id}") is None + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_can_convert_from_text_output_type_to_text_field( + premium_data_fixture, api_client +): + premium_data_fixture.register_fake_generate_ai_type() + user = premium_data_fixture.create_user() + database = premium_data_fixture.create_database_application( + user=user, name="Placeholder" + ) + table = premium_data_fixture.create_database_table( + name="Example", database=database + ) + field = premium_data_fixture.create_ai_field(table=table, order=0, name="ai") + + model = table.get_model() + model.objects.create(**{f"field_{field.id}": "Test"}) + + field = FieldHandler().update_field( + user=user, + field=field, + new_type_name="text", + ) + + table.refresh_from_db() + model = table.get_model() + rows = list(model.objects.all()) + + # Converting text ai field to text field should keep the values because the text + # field conversion is automatically used. + assert getattr(rows[0], f"field_{field.id}") == "Test" diff --git a/premium/web-frontend/modules/baserow_premium/aiFieldOutputTypes.js b/premium/web-frontend/modules/baserow_premium/aiFieldOutputTypes.js new file mode 100644 index 000000000..8302700da --- /dev/null +++ b/premium/web-frontend/modules/baserow_premium/aiFieldOutputTypes.js @@ -0,0 +1,92 @@ +import { Registerable } from '@baserow/modules/core/registry' + +import { + LongTextFieldType, + SingleSelectFieldType, +} from '@baserow/modules/database/fieldTypes' +import FieldSelectOptionsSubForm from '@baserow/modules/database/components/field/FieldSelectOptionsSubForm.vue' + +export class AIFieldOutputType extends Registerable { + /** + * A human readable name of the AI output type. This will be shown in in the dropdown + * where the user chosen the type. + */ + getName() { + return null + } + + /** + * A human-readable description of the AI output type. + */ + getDescription() { + return null + } + + constructor(...args) { + super(...args) + this.type = this.getType() + + if (this.type === null) { + throw new Error('The type name of an admin type must be set.') + } + if (this.name === null) { + throw new Error('The name of an admin type must be set.') + } + } + + getBaserowFieldType() { + throw new Error('The Baserow field type must be set on the AI output type.') + } + + /** + * Can optionally return a form component that will be added to the `FieldForm` is + * the output type is chosen. + */ + getFormComponent() { + return null + } +} + +export class TextAIFieldOutputType extends AIFieldOutputType { + static getType() { + return 'text' + } + + getName() { + const { i18n } = this.app + return i18n.t('aiOutputType.text') + } + + getDescription() { + const { i18n } = this.app + return i18n.t('aiOutputType.textDescription') + } + + getBaserowFieldType() { + return this.app.$registry.get('field', LongTextFieldType.getType()) + } +} + +export class ChoiceAIFieldOutputType extends AIFieldOutputType { + static getType() { + return 'choice' + } + + getName() { + const { i18n } = this.app + return i18n.t('aiOutputType.choice') + } + + getDescription() { + const { i18n } = this.app + return i18n.t('aiOutputType.choiceDescription') + } + + getBaserowFieldType() { + return this.app.$registry.get('field', SingleSelectFieldType.getType()) + } + + getFormComponent() { + return FieldSelectOptionsSubForm + } +} diff --git a/premium/web-frontend/modules/baserow_premium/components/field/FieldAISubForm.vue b/premium/web-frontend/modules/baserow_premium/components/field/FieldAISubForm.vue index 0a94bab30..3fe21cb7e 100644 --- a/premium/web-frontend/modules/baserow_premium/components/field/FieldAISubForm.vue +++ b/premium/web-frontend/modules/baserow_premium/components/field/FieldAISubForm.vue @@ -35,6 +35,31 @@ /> </Dropdown> </FormGroup> + + <FormGroup + required + small-label + :label="$t('fieldAISubForm.outputType')" + :help-icon-tooltip="$t('fieldAISubForm.outputTypeTooltip')" + > + <Dropdown + v-model="values.ai_output_type" + class="dropdown--floating" + :fixed-items="true" + > + <DropdownItem + v-for="outputType in outputTypes" + :key="outputType.getType()" + :name="outputType.getName()" + :value="outputType.getType()" + :description="outputType.getDescription()" + /> + </Dropdown> + <template v-if="changedOutputType" #warning> + {{ $t('fieldAISubForm.outputTypeChangedWarning') }} + </template> + </FormGroup> + <FormGroup small-label :label="$t('fieldAISubForm.prompt')" @@ -52,6 +77,12 @@ </div> <template #error> {{ $t('error.requiredField') }}</template> </FormGroup> + + <component + :is="outputType.getFormComponent()" + ref="childForm" + v-bind="$props" + /> </div> <div v-else> <p> @@ -67,6 +98,7 @@ import form from '@baserow/modules/core/mixins/form' import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm' import FormulaInputField from '@baserow/modules/core/components/formula/FormulaInputField' import SelectAIModelForm from '@baserow/modules/core/components/ai/SelectAIModelForm' +import { TextAIFieldOutputType } from '@baserow_premium/aiFieldOutputTypes' export default { name: 'FieldAISubForm', @@ -74,9 +106,10 @@ export default { mixins: [form, fieldSubForm], data() { return { - allowedValues: ['ai_prompt', 'ai_file_field_id'], + allowedValues: ['ai_prompt', 'ai_file_field_id', 'ai_output_type'], values: { ai_prompt: '', + ai_output_type: TextAIFieldOutputType.getType(), ai_file_field_id: null, }, fileFieldSupported: false, @@ -113,6 +146,19 @@ export default { return t.canRepresentFiles(field) }) }, + outputTypes() { + return Object.values(this.$registry.getAll('aiFieldOutputType')) + }, + outputType() { + return this.$registry.get('aiFieldOutputType', this.values.ai_output_type) + }, + changedOutputType() { + return ( + this.defaultValues.id && + this.defaultValues.type === this.values.type && + this.defaultValues.ai_output_type !== this.values.ai_output_type + ) + }, }, methods: { setFileFieldSupported(generativeAIType) { diff --git a/premium/web-frontend/modules/baserow_premium/components/row/RowEditFieldAI.vue b/premium/web-frontend/modules/baserow_premium/components/row/RowEditFieldAI.vue index 012388c1d..d04b57d20 100644 --- a/premium/web-frontend/modules/baserow_premium/components/row/RowEditFieldAI.vue +++ b/premium/web-frontend/modules/baserow_premium/components/row/RowEditFieldAI.vue @@ -1,15 +1,12 @@ <template> - <div class="control__elements"> - <FormTextarea - ref="input" - v-model="value" - type="text" - class="margin-bottom-2" - :rows="6" - :disabled="true" - /> - - <template v-if="!readOnly"> + <div> + <component + :is="outputRowEditFieldComponent" + ref="field" + v-bind="$props" + :read-only="true" + ></component> + <div v-if="!readOnly" class="margin-top-2"> <Button v-if="isDeactivated && rowIsCreated" type="secondary" @@ -19,7 +16,7 @@ {{ $t('rowEditFieldAI.generate') }} </Button> <Button - v-if="rowIsCreated" + v-else-if="rowIsCreated" type="secondary" :loading="generating" @click="generate()" @@ -33,7 +30,7 @@ :workspace="workspace" :name="fieldName" ></component> - </template> + </div> </div> </template> @@ -48,6 +45,12 @@ export default { fieldName() { return this.$registry.get('field', this.field.type).getName() }, + outputRowEditFieldComponent() { + return this.$registry + .get('aiFieldOutputType', this.field.ai_output_type) + .getBaserowFieldType() + .getRowEditFieldComponent(this.field) + }, }, } </script> diff --git a/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI.vue b/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI.vue index 09b398886..1276ba8a7 100644 --- a/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI.vue +++ b/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI.vue @@ -23,9 +23,15 @@ </Button> </div> </div> - <div v-else class="grid-view__cell grid-field-long-text__cell"> - <div class="grid-field-long-text">{{ props.value }}</div> - </div> + <component + :is="$options.methods.getFunctionalOutputFieldComponent(parent, props)" + v-else + :workspace-id="props.workspaceId" + :field="props.field" + :value="props.value" + :state="props.state" + :read-only="props.readOnly" + /> </template> <script> @@ -41,6 +47,12 @@ export default { .get('field', AIFieldType.getType()) .isDeactivated(props.workspaceId) }, + getFunctionalOutputFieldComponent(parent, props) { + return parent.$registry + .get('aiFieldOutputType', props.field.ai_output_type) + .getBaserowFieldType() + .getFunctionalGridViewFieldComponent(props.field) + }, }, } </script> diff --git a/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/GridViewFieldAI.vue b/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/GridViewFieldAI.vue index c5067dbeb..bb04c636c 100644 --- a/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/GridViewFieldAI.vue +++ b/premium/web-frontend/modules/baserow_premium/components/views/grid/fields/GridViewFieldAI.vue @@ -18,40 +18,34 @@ </Button> </div> </div> - <div + <component + :is="outputGridViewFieldComponent" v-else ref="cell" - class="grid-view__cell grid-field-long-text__cell active" - :class="{ editing: opened }" - @keyup.enter="opened = true" + v-bind="$props" + :read-only="true" > - <div v-if="!opened" class="grid-field-long-text">{{ value }}</div> - <template v-else> - <div class="grid-field-long-text__textarea"> - {{ value }} - </div> + <template v-if="!readOnly && editing" #default="{ editing }"> <div style="background-color: #fff; padding: 8px"> - <template v-if="!readOnly"> - <ButtonText - v-if="!isDeactivated" - icon="iconoir-magic-wand" - :disabled="!modelAvailable || generating" - :loading="generating" - @click.prevent.stop="generate()" - > - {{ $t('gridViewFieldAI.regenerate') }} - </ButtonText> - <ButtonText - v-else - icon="iconoir-lock" - @click.prevent.stop="$refs.clickModal.show()" - > - {{ $t('gridViewFieldAI.regenerate') }} - </ButtonText> - </template> + <ButtonText + v-if="!isDeactivated" + icon="iconoir-magic-wand" + :disabled="!modelAvailable || generating" + :loading="generating" + @click.prevent.stop="generate()" + > + {{ $t('gridViewFieldAI.regenerate') }} + </ButtonText> + <ButtonText + v-else + icon="iconoir-lock" + @click.prevent.stop="$refs.clickModal.show()" + > + {{ $t('gridViewFieldAI.regenerate') }} + </ButtonText> </div> </template> - </div> + </component> <component :is="deactivatedClickComponent" v-if="isDeactivated && workspace" @@ -75,6 +69,12 @@ export default { fieldName() { return this.$registry.get('field', this.field.type).getName() }, + outputGridViewFieldComponent() { + return this.$registry + .get('aiFieldOutputType', this.field.ai_output_type) + .getBaserowFieldType() + .getGridViewFieldComponent(this.field) + }, }, methods: { save() { diff --git a/premium/web-frontend/modules/baserow_premium/fieldTypes.js b/premium/web-frontend/modules/baserow_premium/fieldTypes.js index 08055e655..83be13a27 100644 --- a/premium/web-frontend/modules/baserow_premium/fieldTypes.js +++ b/premium/web-frontend/modules/baserow_premium/fieldTypes.js @@ -2,13 +2,6 @@ import { FieldType, FormulaFieldType, } from '@baserow/modules/database/fieldTypes' -import RowHistoryFieldText from '@baserow/modules/database/components/row/RowHistoryFieldText' -import RowCardFieldText from '@baserow/modules/database/components/card/RowCardFieldText' -import { collatedStringCompare } from '@baserow/modules/core/utils/string' -import { - genericContainsFilter, - genericContainsWordFilter, -} from '@baserow/modules/database/utils/fieldFilters' import GridViewFieldAI from '@baserow_premium/components/views/grid/fields/GridViewFieldAI' import FunctionalGridViewFieldAI from '@baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI' @@ -33,6 +26,10 @@ export class AIFieldType extends FieldType { return i18n.t('premiumFieldType.ai') } + getIsReadOnly() { + return true + } + getGridViewFieldComponent() { return GridViewFieldAI } @@ -45,18 +42,21 @@ export class AIFieldType extends FieldType { return RowEditFieldAI } - getCardComponent() { - return RowCardFieldText - } - - getRowHistoryEntryComponent() { - return RowHistoryFieldText - } - getFormComponent() { return FieldAISubForm } + getCardComponent(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getCardComponent(field) + } + + getRowHistoryEntryComponent(field) { + return null + } + getFormViewFieldComponents(field) { return {} } @@ -65,17 +65,32 @@ export class AIFieldType extends FieldType { return null } - getSort(name, order) { - return (a, b) => { - const stringA = a[name] === null ? '' : '' + a[name] - const stringB = b[name] === null ? '' : '' + b[name] + getCardValueHeight(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getCardValueHeight(field) + } - return collatedStringCompare(stringA, stringB, order) - } + getSort(name, order, field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getSort(name, order, field) + } + + getCanSortInView(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getCanSortInView(field) } getDocsDataType(field) { - return 'string' + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getDocsDataType(field) } getDocsDescription(field) { @@ -84,15 +99,147 @@ export class AIFieldType extends FieldType { } getDocsRequestExample(field) { - return 'string' + return 'read only' } - getContainsFilterFunction() { - return genericContainsFilter + getDocsResponseExample(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getDocsResponseExample(field) + } + + prepareValueForCopy(field, value) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .prepareValueForCopy(field, value) + } + + getContainsFilterFunction(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getContainsFilterFunction(field) } getContainsWordFilterFunction(field) { - return genericContainsWordFilter + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getContainsWordFilterFunction(field) + } + + toHumanReadableString(field, value) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .toHumanReadableString(field, value) + } + + getSortIndicator(field, registry) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getSortIndicator(field, registry) + } + + canRepresentDate(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .canRepresentDate(field) + } + + getCanGroupByInView(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getCanGroupByInView(field) + } + + parseInputValue(field, value) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .parseInputValue(field, value) + } + + canRepresentFiles(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .canRepresentFiles(field) + } + + getHasEmptyValueFilterFunction(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getHasEmptyValueFilterFunction(field) + } + + getHasValueContainsFilterFunction(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getHasValueContainsFilterFunction(field) + } + + getHasValueContainsWordFilterFunction(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getHasValueContainsWordFilterFunction(field) + } + + getHasValueLengthIsLowerThanFilterFunction(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getHasValueLengthIsLowerThanFilterFunction(field) + } + + getGroupByComponent(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getGroupByComponent(field) + } + + getGroupByIndicator(field, registry) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getGroupByIndicator(field, registry) + } + + getRowValueFromGroupValue(field, value) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getRowValueFromGroupValue(field, value) + } + + getGroupValueFromRowValue(field, value) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .getGroupValueFromRowValue(field, value) + } + + isEqual(field, value1, value2) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .isEqual(field, value1, value2) + } + + canBeReferencedByFormulaField(field) { + return this.app.$registry + .get('aiFieldOutputType', field.ai_output_type) + .getBaserowFieldType() + .canBeReferencedByFormulaField(field) } getGridViewContextItemsOnCellsSelection(field) { diff --git a/premium/web-frontend/modules/baserow_premium/locales/en.json b/premium/web-frontend/modules/baserow_premium/locales/en.json index 2de930e6c..59b6a3cef 100644 --- a/premium/web-frontend/modules/baserow_premium/locales/en.json +++ b/premium/web-frontend/modules/baserow_premium/locales/en.json @@ -409,7 +409,10 @@ "promptPlaceholder": "What is Baserow?", "premiumFeature": "The AI field is a premium feature", "emptyFileField": "None", - "fileFieldHelp": "The first compatible file in the field will be used as the knowledge base for the prompt. The file has to be a text file with the supported file extension like .txt, .md, .pdf, .docx." + "fileFieldHelp": "The first compatible file in the field will be used as the knowledge base for the prompt. The file has to be a text file with the supported file extension like .txt, .md, .pdf, .docx.", + "outputType": "Output type", + "outputTypeTooltip": "Select your desired output format to guide the LLM in generating responses that match your desired options.", + "outputTypeChangedWarning": "Clears the generated cell values." }, "rowEditFieldAI": { "generate": "Generate", @@ -426,5 +429,11 @@ "formulaFieldAI": { "generateWithAI": "Generate using AI", "featureName": "Generate formula using AI" + }, + "aiOutputType": { + "text": "Text", + "textDescription": "Generates free text based on the prompt.", + "choice": "Choice", + "choiceDescription": "Chooses only one of the field options." } } diff --git a/premium/web-frontend/modules/baserow_premium/plugin.js b/premium/web-frontend/modules/baserow_premium/plugin.js index 29788c5b0..7176f3647 100644 --- a/premium/web-frontend/modules/baserow_premium/plugin.js +++ b/premium/web-frontend/modules/baserow_premium/plugin.js @@ -52,6 +52,10 @@ import { AIFieldType, PremiumFormulaFieldType, } from '@baserow_premium/fieldTypes' +import { + ChoiceAIFieldOutputType, + TextAIFieldOutputType, +} from '@baserow_premium/aiFieldOutputTypes' export default (context) => { const { store, app, isDev } = context @@ -94,6 +98,8 @@ export default (context) => { store.registerModule('template/view/timeline', timelineStore) store.registerModule('impersonating', impersonatingStore) + app.$registry.registerNamespace('aiFieldOutputType') + app.$registry.register('plugin', new PremiumPlugin(context)) app.$registry.register('admin', new DashboardType(context)) app.$registry.register('admin', new UsersAdminType(context)) @@ -160,4 +166,13 @@ export default (context) => { 'rowModalSidebar', new CommentsRowModalSidebarType(context) ) + + app.$registry.register( + 'aiFieldOutputType', + new TextAIFieldOutputType(context) + ) + app.$registry.register( + 'aiFieldOutputType', + new ChoiceAIFieldOutputType(context) + ) } diff --git a/premium/web-frontend/modules/baserow_premium/viewTypes.js b/premium/web-frontend/modules/baserow_premium/viewTypes.js index 61f294726..b015f3cba 100644 --- a/premium/web-frontend/modules/baserow_premium/viewTypes.js +++ b/premium/web-frontend/modules/baserow_premium/viewTypes.js @@ -180,6 +180,7 @@ export class KanbanViewType extends PremiumViewType { row, values, metadata, + updatedFieldIds, storePrefix = '' ) { if (this.isCurrentView(store, tableId)) { @@ -386,6 +387,7 @@ export class CalendarViewType extends PremiumViewType { row, values, metadata, + updatedFieldIds, storePrefix = '' ) { if (this.isCurrentView(store, tableId)) { diff --git a/web-frontend/modules/core/assets/scss/components/views/grid/single_select.scss b/web-frontend/modules/core/assets/scss/components/views/grid/single_select.scss index a45411811..9f2e60c09 100644 --- a/web-frontend/modules/core/assets/scss/components/views/grid/single_select.scss +++ b/web-frontend/modules/core/assets/scss/components/views/grid/single_select.scss @@ -1,3 +1,13 @@ +.grid-field-single-select__cell { + &.active { + bottom: auto; + right: auto; + height: auto; + min-height: calc(100% + 4px); + min-width: calc(100% + 4px); + } +} + .grid-field-single-select { position: relative; display: block; diff --git a/web-frontend/modules/database/components/field/FieldFormulaSubForm.vue b/web-frontend/modules/database/components/field/FieldFormulaSubForm.vue index 5126d0308..7ddae8d05 100644 --- a/web-frontend/modules/database/components/field/FieldFormulaSubForm.vue +++ b/web-frontend/modules/database/components/field/FieldFormulaSubForm.vue @@ -84,7 +84,7 @@ export default { const isNotThisField = f.id !== this.defaultValues.id const canBeReferencedByFormulaField = this.$registry .get('field', f.type) - .canBeReferencedByFormulaField() + .canBeReferencedByFormulaField(f) return isNotThisField && canBeReferencedByFormulaField }) }, diff --git a/web-frontend/modules/database/components/field/FieldSelectTargetFieldSubForm.vue b/web-frontend/modules/database/components/field/FieldSelectTargetFieldSubForm.vue index f557225f1..4bb92fbc7 100644 --- a/web-frontend/modules/database/components/field/FieldSelectTargetFieldSubForm.vue +++ b/web-frontend/modules/database/components/field/FieldSelectTargetFieldSubForm.vue @@ -114,7 +114,7 @@ export default { .filter((f) => { return this.$registry .get('field', f.type) - .canBeReferencedByFormulaField() + .canBeReferencedByFormulaField(f) }) .filter((f) => { return this.$hasPermission( diff --git a/web-frontend/modules/database/components/view/grid/fields/FunctionalGridViewFieldSingleSelect.vue b/web-frontend/modules/database/components/view/grid/fields/FunctionalGridViewFieldSingleSelect.vue index 7050a5bc6..385d9f408 100644 --- a/web-frontend/modules/database/components/view/grid/fields/FunctionalGridViewFieldSingleSelect.vue +++ b/web-frontend/modules/database/components/view/grid/fields/FunctionalGridViewFieldSingleSelect.vue @@ -1,5 +1,9 @@ <template functional> - <div ref="cell" class="grid-view__cell" :class="data.staticClass || ''"> + <div + ref="cell" + class="grid-view__cell grid-field-single-select__cell" + :class="data.staticClass || ''" + > <div class="grid-field-single-select"> <div v-if="props.value" diff --git a/web-frontend/modules/database/components/view/grid/fields/GridViewFieldLongText.vue b/web-frontend/modules/database/components/view/grid/fields/GridViewFieldLongText.vue index ad6d24271..a8d8117c6 100644 --- a/web-frontend/modules/database/components/view/grid/fields/GridViewFieldLongText.vue +++ b/web-frontend/modules/database/components/view/grid/fields/GridViewFieldLongText.vue @@ -15,6 +15,7 @@ class="grid-field-long-text__textarea" /> <div v-else class="grid-field-long-text__textarea">{{ value }}</div> + <slot name="default" :slot-props="{ editing, opened }"></slot> </div> </template> diff --git a/web-frontend/modules/database/components/view/grid/fields/GridViewFieldSingleSelect.vue b/web-frontend/modules/database/components/view/grid/fields/GridViewFieldSingleSelect.vue index 86bdbf1a2..cffc744cd 100644 --- a/web-frontend/modules/database/components/view/grid/fields/GridViewFieldSingleSelect.vue +++ b/web-frontend/modules/database/components/view/grid/fields/GridViewFieldSingleSelect.vue @@ -1,5 +1,5 @@ <template> - <div ref="cell" class="grid-view__cell active"> + <div ref="cell" class="grid-view__cell grid-field-single-select__cell active"> <div ref="dropdownLink" class="grid-field-single-select grid-field-single-select--selected" @@ -18,6 +18,7 @@ class="iconoir-nav-arrow-down grid-field-single-select__icon" ></i> </div> + <slot name="default" :slot-props="{ editing, opened: true }"></slot> <FieldSelectOptionsDropdown v-if="!readOnly" ref="dropdown" diff --git a/web-frontend/modules/database/fieldTypes.js b/web-frontend/modules/database/fieldTypes.js index 7350df98c..4886632ac 100644 --- a/web-frontend/modules/database/fieldTypes.js +++ b/web-frontend/modules/database/fieldTypes.js @@ -727,7 +727,7 @@ export class FieldType extends Registerable { * Override and return true if the field type can be referenced by a formula field. * @return {boolean} */ - canBeReferencedByFormulaField() { + canBeReferencedByFormulaField(field) { return false } diff --git a/web-frontend/modules/database/realtime.js b/web-frontend/modules/database/realtime.js index b2c5ccce2..78ec98c7e 100644 --- a/web-frontend/modules/database/realtime.js +++ b/web-frontend/modules/database/realtime.js @@ -202,6 +202,7 @@ export const registerRealtimeEvents = (realtime) => { rowBeforeUpdate, row, data.metadata[row.id], + data.updated_field_ids, 'page/' ) } diff --git a/web-frontend/modules/database/store/view/grid.js b/web-frontend/modules/database/store/view/grid.js index 7faa5e385..01c02328f 100644 --- a/web-frontend/modules/database/store/view/grid.js +++ b/web-frontend/modules/database/store/view/grid.js @@ -2506,7 +2506,7 @@ export const actions = { */ async updatedExistingRow( { commit, getters, dispatch }, - { view, fields, row, values, metadata } + { view, fields, row, values, metadata, updatedFieldIds = [] } ) { const oldRow = clone(row) const newRow = Object.assign(clone(row), values) @@ -2646,15 +2646,21 @@ export const actions = { // sure the loading state will stop if the value is updated. This is done even // if the row is not found in the buffer because it could have been removed from // the buffer when scrolling outside the buffer range. - const updatedFieldIds = Object.entries(values) + const getFieldId = (key) => parseInt(key.split('_')[1]) + const fieldIdsToClearPendingOperationsFor = Object.entries(values) .filter( ([key, value]) => - key.startsWith('field_') && !_.isEqual(value, oldRow[key]) + key.startsWith('field_') && + // Either the value has changed. + (_.isEqual(value, oldRow[key]) || + // Or the backend has just recalculated the value, even if it hasn't + // actually changed. + updatedFieldIds.includes(getFieldId(key))) ) - .map(([key, value]) => parseInt(key.split('_')[1])) + .map(([key, value]) => getFieldId(key)) commit('CLEAR_PENDING_FIELD_OPERATIONS', { - fieldIds: updatedFieldIds, + fieldIds: fieldIdsToClearPendingOperationsFor, rowId: row.id, }) diff --git a/web-frontend/modules/database/viewTypes.js b/web-frontend/modules/database/viewTypes.js index c56f82e99..b131f73a9 100644 --- a/web-frontend/modules/database/viewTypes.js +++ b/web-frontend/modules/database/viewTypes.js @@ -267,7 +267,16 @@ export class ViewType extends Registerable { * via a real time event by another user. It can be used to check if data in an store * needs to be updated. */ - rowUpdated(context, tableId, fields, row, values, metadata, storePrefix) {} + rowUpdated( + context, + tableId, + fields, + row, + values, + metadata, + updatedFieldIds, + storePrefix + ) {} /** * Event that is called when something went wrong while generating AI values @@ -722,6 +731,7 @@ export class GridViewType extends ViewType { row, values, metadata, + updatedFieldIds, storePrefix = '' ) { if (this.isCurrentView(store, tableId)) { @@ -731,6 +741,7 @@ export class GridViewType extends ViewType { row, values, metadata, + updatedFieldIds, }) await store.dispatch(storePrefix + 'view/grid/fetchByScrollTopDelayed', { scrollTop: store.getters[storePrefix + 'view/grid/getScrollTop'], @@ -971,6 +982,7 @@ export const BaseBufferedRowViewTypeMixin = (Base) => row, values, metadata, + updatedFieldIds, storePrefix = '' ) { if (this.isCurrentView(store, tableId)) {