mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-04 13:15:24 +00:00
Merge branch 'ai-field-classification' into 'develop'
[2/2] AI improvements: AI field choice output type Closes #3143 See merge request baserow/baserow!2825
This commit is contained in:
commit
4e29d1afba
47 changed files with 1533 additions and 146 deletions
backend
src/baserow
contrib/database
test_utils
tests/baserow/contrib
database
api
field
import_export
ws
integrations/local_baserow
changelog/entries/unreleased/feature
enterprise/backend
src/baserow_enterprise/data_sync
tests/baserow_enterprise_tests/data_sync
premium
backend
web-frontend/modules/baserow_premium
web-frontend/modules
core/assets/scss/components/views/grid
database
|
@ -254,8 +254,20 @@ def construct_all_possible_field_kwargs(
|
|||
"name": "ai",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_output_type": "text",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "ai_choice",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "'What are you?'",
|
||||
"ai_output_type": "choice",
|
||||
"select_options": [
|
||||
{"id": 5, "value": "Object", "color": "orange"},
|
||||
{"id": 6, "value": "Else", "color": "yellow"},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
# If you have added a new field please add an entry into the dict above with any
|
||||
|
|
|
@ -975,7 +975,7 @@ class FieldType(
|
|||
field_id = serialized_copy.pop("id")
|
||||
serialized_copy.pop("type")
|
||||
select_options = (
|
||||
serialized_copy.pop("select_options")
|
||||
serialized_copy.pop("select_options", [])
|
||||
if self.can_have_select_options
|
||||
else []
|
||||
)
|
||||
|
|
|
@ -1976,6 +1976,8 @@ class ViewHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
:return: The created view sort instance.
|
||||
"""
|
||||
|
||||
field = field.specific
|
||||
|
||||
workspace = view.table.database.workspace
|
||||
CoreHandler().check_permissions(
|
||||
user, ReadFieldOperationType.type, workspace=workspace, context=field
|
||||
|
|
|
@ -274,6 +274,7 @@ def public_rows_updated(
|
|||
table_id=PUBLIC_PLACEHOLDER_ENTITY_ID,
|
||||
serialized_rows_before_update=visible_fields_only_old_rows,
|
||||
serialized_rows=visible_fields_only_updated_rows,
|
||||
updated_field_ids=list(updated_field_ids),
|
||||
metadata={},
|
||||
),
|
||||
slug=public_view.slug,
|
||||
|
|
|
@ -81,6 +81,9 @@ def rows_updated(
|
|||
serialized_rows=get_row_serializer_class(
|
||||
model, RowSerializer, is_response=True
|
||||
)(rows, many=True).data,
|
||||
# Broadcast a list of updated fields so that the listener can take
|
||||
# action even if the value didn't change.
|
||||
updated_field_ids=list(updated_field_ids),
|
||||
metadata=row_metadata_registry.generate_and_merge_metadata_for_rows(
|
||||
user, table, [row.id for row in rows]
|
||||
),
|
||||
|
@ -213,6 +216,7 @@ class RealtimeRowMessages:
|
|||
serialized_rows_before_update: List[Dict[str, Any]],
|
||||
serialized_rows: List[Dict[str, Any]],
|
||||
metadata: Dict[int, Dict[str, Any]],
|
||||
updated_field_ids: List[int],
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "rows_updated",
|
||||
|
@ -223,6 +227,7 @@ class RealtimeRowMessages:
|
|||
"rows_before_update": serialized_rows_before_update,
|
||||
"rows": serialized_rows,
|
||||
"metadata": metadata,
|
||||
"updated_field_ids": updated_field_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -250,6 +250,9 @@ def setup_interesting_test_table(
|
|||
"phone_number": "+4412345678",
|
||||
"password": "test",
|
||||
"ai": "I'm an AI.",
|
||||
"ai_choice": SelectOption.objects.get(
|
||||
value="Object", field_id=name_to_field_id["ai_choice"]
|
||||
).id,
|
||||
}
|
||||
|
||||
with freeze_time("2020-02-01 01:23"):
|
||||
|
|
|
@ -378,6 +378,11 @@ def test_get_row_serializer_with_user_field_names(data_fixture):
|
|||
"autonumber": 2,
|
||||
"password": True,
|
||||
"ai": "I'm an AI.",
|
||||
"ai_choice": {
|
||||
"color": "orange",
|
||||
"id": SelectOption.objects.get(value="Object").id,
|
||||
"value": "Object",
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
|
@ -84,6 +84,8 @@ def test_serialize_group_by_metadata_on_all_fields_in_interesting_table(data_fix
|
|||
|
||||
actual_result_per_field_name = {}
|
||||
|
||||
ai_choice_select_options = Field.objects.get(name="ai_choice").select_options.all()
|
||||
|
||||
for field in fields_to_group_by:
|
||||
counts = handler.get_group_by_metadata_in_rows([field], rows, queryset)
|
||||
serialized = serialize_group_by_metadata(counts)[field.db_column]
|
||||
|
@ -258,4 +260,12 @@ def test_serialize_group_by_metadata_on_all_fields_in_interesting_table(data_fix
|
|||
{"count": 1, "field_duration_dhms": 90066.0},
|
||||
{"count": 1, "field_duration_dhms": None},
|
||||
],
|
||||
"ai": [
|
||||
{"count": 1, "field_ai": "I'm an AI."},
|
||||
{"count": 1, "field_ai": None},
|
||||
],
|
||||
"ai_choice": [
|
||||
{"count": 1, "field_ai_choice": ai_choice_select_options[0].id},
|
||||
{"count": 1, "field_ai_choice": None},
|
||||
],
|
||||
}
|
||||
|
|
|
@ -303,7 +303,9 @@ def test_link_row_field_type_with_text_values(data_fixture):
|
|||
for field_type in [
|
||||
f
|
||||
for f in field_type_registry.get_all()
|
||||
if f.can_get_unique_values and not f.read_only
|
||||
# The AI field is not compatible because it some field kwargs are required and
|
||||
# not passed in.
|
||||
if f.can_get_unique_values and not f.read_only and f.type != "ai"
|
||||
]:
|
||||
field_type_name = field_type.type
|
||||
field_name = f"Field {field_type_name}"
|
||||
|
|
|
@ -247,12 +247,12 @@ def test_can_export_every_interesting_different_field_to_csv(
|
|||
"phone_number,formula_text,formula_int,formula_bool,formula_decimal,formula_dateinterval,"
|
||||
"formula_date,formula_singleselect,formula_email,formula_link_with_label,"
|
||||
"formula_link_url_only,formula_multipleselect,count,rollup,duration_rollup_sum,"
|
||||
"duration_rollup_avg,lookup,uuid,autonumber,password,ai\r\n"
|
||||
"duration_rollup_avg,lookup,uuid,autonumber,password,ai,ai_choice\r\n"
|
||||
"1,,,,,,,,,0,False,,,,,,,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
||||
"02/01/2021 13:00,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,02/01/2021 13:00,"
|
||||
"user@example.com,user@example.com,,,,,,,,,,,,,,,,,,,test FORMULA,1,True,33.3333333333,"
|
||||
"1d 0:00,2020-01-01,,,label (https://google.com),https://google.com,,0,0.000,"
|
||||
"0:00,0:00,,00000000-0000-4000-8000-000000000001,1,,\r\n"
|
||||
"0:00,0:00,,00000000-0000-4000-8000-000000000001,1,,,\r\n"
|
||||
"2,text,long_text,https://www.google.com,test@example.com,-1,1,-1.2,1.2,3,True,"
|
||||
"02/01/2020 01:23,02/01/2020,01/02/2020 01:23,01/02/2020,01/02/2020 02:23,"
|
||||
"01/02/2020 02:23,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
||||
|
@ -267,7 +267,7 @@ def test_can_export_every_interesting_different_field_to_csv(
|
|||
'"user2@example.com,user3@example.com",\'+4412345678,test FORMULA,1,True,33.3333333333,'
|
||||
"1d 0:00,2020-01-01,A,test@example.com,label (https://google.com),https://google.com,"
|
||||
'"C,D,E",3,-122.222,0:04,0:02,"linked_row_1,linked_row_2,",'
|
||||
"00000000-0000-4000-8000-000000000002,2,True,I'm an AI.\r\n"
|
||||
"00000000-0000-4000-8000-000000000002,2,True,I'm an AI.,Object\r\n"
|
||||
)
|
||||
|
||||
assert contents == expected
|
||||
|
|
|
@ -1014,6 +1014,7 @@ def test_batch_update_rows_some_not_visible_in_public_view_to_be_visible_event_s
|
|||
},
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [hidden_field.id],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
@ -1151,6 +1152,7 @@ def test_batch_update_rows_visible_in_public_view_to_some_not_be_visible_event_s
|
|||
},
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [hidden_field.id],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
@ -1454,6 +1456,7 @@ def test_given_row_visible_in_public_view_when_updated_to_still_be_visible_event
|
|||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [visible_field.id, hidden_field.id],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
@ -1580,6 +1583,7 @@ def test_batch_update_rows_visible_in_public_view_still_be_visible_event_sent(
|
|||
},
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [visible_field.id, hidden_field.id],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
@ -1661,6 +1665,7 @@ def test_batch_update_subset_rows_visible_in_public_view_no_filters(
|
|||
},
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [visible_field.id],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
@ -2021,6 +2026,7 @@ def test_given_row_visible_in_public_view_when_moved_row_updated_sent(
|
|||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
|
|
@ -315,6 +315,7 @@ def test_rows_history_updated(mock_broadcast_channel_group, data_fixture):
|
|||
),
|
||||
],
|
||||
"metadata": {},
|
||||
"updated_field_ids": [field.id],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
|
|
|
@ -864,6 +864,21 @@ def test_local_baserow_table_service_generate_schema_with_interesting_test_table
|
|||
"metadata": {},
|
||||
"type": "string",
|
||||
},
|
||||
field_db_column_by_name["ai_choice"]: {
|
||||
"title": "ai_choice",
|
||||
"default": None,
|
||||
"searchable": True,
|
||||
"sortable": True,
|
||||
"filterable": False,
|
||||
"original_type": "ai",
|
||||
"metadata": {},
|
||||
"properties": {
|
||||
"color": {"title": "color", "type": "string"},
|
||||
"id": {"title": "id", "type": "number"},
|
||||
"value": {"title": "value", "type": "string"},
|
||||
},
|
||||
"type": "object",
|
||||
},
|
||||
"id": {
|
||||
"type": "number",
|
||||
"title": "Id",
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "feature",
|
||||
"message": "AI choice output type for classification purposes.",
|
||||
"issue_number": 3143,
|
||||
"bullet_points": [],
|
||||
"created_at": "2024-11-10"
|
||||
}
|
|
@ -5,6 +5,7 @@ from uuid import UUID
|
|||
from django.db.models import Prefetch
|
||||
|
||||
from baserow_premium.fields.field_types import AIFieldType
|
||||
from baserow_premium.fields.registries import ai_field_output_registry
|
||||
from baserow_premium.license.handler import LicenseHandler
|
||||
from rest_framework import serializers
|
||||
|
||||
|
@ -33,7 +34,6 @@ from baserow.contrib.database.fields.field_types import (
|
|||
from baserow.contrib.database.fields.models import (
|
||||
DateField,
|
||||
Field,
|
||||
LongTextField,
|
||||
NumberField,
|
||||
SelectOption,
|
||||
TextField,
|
||||
|
@ -51,16 +51,21 @@ from baserow_enterprise.features import DATA_SYNC
|
|||
from .models import LocalBaserowTableDataSync
|
||||
|
||||
|
||||
def prepare_single_select_value(value, enabled_property):
|
||||
def prepare_single_select_value(value, field, metadata):
|
||||
try:
|
||||
# The metadata contains a mapping of the select options where the key is the
|
||||
# old ID and the value is the new ID. For some reason the key is converted to
|
||||
# a string when moved into the JSON field.
|
||||
return enabled_property.metadata["select_options_mapping"][str(value)]
|
||||
return metadata["select_options_mapping"][str(value)]
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def prepare_ai_value(value, field, metadata):
|
||||
output_type = ai_field_output_registry.get(field.ai_output_type)
|
||||
return output_type.prepare_data_sync_value(value, field, metadata)
|
||||
|
||||
|
||||
# List of Baserow supported field types that can be included in the data sync.
|
||||
supported_field_types = {
|
||||
TextFieldType.type: {},
|
||||
|
@ -78,7 +83,7 @@ supported_field_types = {
|
|||
LastModifiedFieldType.type: {},
|
||||
UUIDFieldType.type: {},
|
||||
AutonumberFieldType.type: {},
|
||||
AIFieldType.type: {},
|
||||
AIFieldType.type: {"prepare_value": prepare_ai_value},
|
||||
SingleSelectFieldType.type: {"prepare_value": prepare_single_select_value},
|
||||
}
|
||||
|
||||
|
@ -99,7 +104,6 @@ class BaserowFieldDataSyncProperty(DataSyncProperty):
|
|||
LastModifiedFieldType.type: DateField,
|
||||
UUIDFieldType.type: TextField,
|
||||
AutonumberFieldType.type: NumberField,
|
||||
AIFieldType.type: LongTextField,
|
||||
}
|
||||
|
||||
def __init__(self, field, immutable_properties, **kwargs):
|
||||
|
@ -332,7 +336,9 @@ class LocalBaserowTableDataSyncType(DataSyncType):
|
|||
if "prepare_value" in supported_field:
|
||||
for row in rows_queryset:
|
||||
row[enabled_property.key] = supported_field["prepare_value"](
|
||||
row[enabled_property.key], enabled_property
|
||||
row[enabled_property.key],
|
||||
enabled_property.field,
|
||||
enabled_property.metadata,
|
||||
)
|
||||
progress.increment(by=2) # makes the total `10`
|
||||
|
||||
|
|
|
@ -389,6 +389,7 @@ def test_sync_data_sync_table_with_interesting_table_as_source(enterprise_data_f
|
|||
"last_modified_datetime_eu_tzone": "02/01/2021 13:00",
|
||||
"autonumber": "1",
|
||||
"ai": "",
|
||||
"ai_choice": "",
|
||||
"uuid": "00000000-0000-4000-8000-000000000001",
|
||||
}
|
||||
assert results == {
|
||||
|
@ -432,6 +433,7 @@ def test_sync_data_sync_table_with_interesting_table_as_source(enterprise_data_f
|
|||
"last_modified_datetime_eu_tzone": "02/01/2021 13:00",
|
||||
"autonumber": "2",
|
||||
"ai": "I'm an AI.",
|
||||
"ai_choice": "Object",
|
||||
"uuid": "00000000-0000-4000-8000-000000000002",
|
||||
}
|
||||
|
||||
|
|
|
@ -23,13 +23,21 @@ class BaserowPremiumConfig(AppConfig):
|
|||
)
|
||||
|
||||
from .fields.actions import GenerateFormulaWithAIActionType
|
||||
from .fields.ai_field_output_types import (
|
||||
ChoiceAIFieldOutputType,
|
||||
TextAIFieldOutputType,
|
||||
)
|
||||
from .fields.field_converters import AIFieldConverter
|
||||
from .fields.field_types import AIFieldType
|
||||
from .fields.registries import ai_field_output_registry
|
||||
|
||||
field_type_registry.register(AIFieldType())
|
||||
|
||||
field_converter_registry.register(AIFieldConverter())
|
||||
|
||||
ai_field_output_registry.register(TextAIFieldOutputType())
|
||||
ai_field_output_registry.register(ChoiceAIFieldOutputType())
|
||||
|
||||
from baserow.contrib.database.rows.registries import row_metadata_registry
|
||||
from baserow.contrib.database.views.registries import (
|
||||
decorator_type_registry,
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
import enum
|
||||
import json
|
||||
from difflib import get_close_matches
|
||||
from typing import Any
|
||||
|
||||
from langchain.output_parsers.enum import EnumOutputParser
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from baserow.contrib.database.fields.field_types import (
|
||||
LongTextFieldType,
|
||||
SingleSelectFieldType,
|
||||
)
|
||||
|
||||
from .registries import AIFieldOutputType
|
||||
|
||||
|
||||
class TextAIFieldOutputType(AIFieldOutputType):
|
||||
type = "text"
|
||||
baserow_field_type = LongTextFieldType
|
||||
|
||||
|
||||
class StrictEnumOutputParser(EnumOutputParser):
|
||||
def get_format_instructions(self) -> str:
|
||||
json_array = json.dumps(self._valid_values)
|
||||
return f"""Categorize the result following these requirements:
|
||||
|
||||
- Select only one option from the JSON array below.
|
||||
- Don't use quotes or commas or partial values, just the option name.
|
||||
- Choose the option that most closely matches the row values.
|
||||
|
||||
```json
|
||||
{json_array}
|
||||
```""" # nosec this falsely marks as hardcoded sql expression, but it's not related
|
||||
# to SQL at all.
|
||||
|
||||
def parse(self, response: str) -> Any:
|
||||
response = response.strip()
|
||||
# Sometimes the LLM responds with a quotes value or with part of the value if
|
||||
# it contains a comma. Finding the close matches helps with selecting the
|
||||
# right value.
|
||||
closest_matches = get_close_matches(
|
||||
response, self._valid_values, n=1, cutoff=0.0
|
||||
)
|
||||
return super().parse(closest_matches[0])
|
||||
|
||||
|
||||
class ChoiceAIFieldOutputType(AIFieldOutputType):
|
||||
type = "choice"
|
||||
baserow_field_type = SingleSelectFieldType
|
||||
|
||||
def get_output_parser(self, ai_field):
|
||||
choices = enum.Enum(
|
||||
"Choices",
|
||||
{
|
||||
f"OPTION_{option.id}": option.value
|
||||
for option in ai_field.select_options.all()
|
||||
},
|
||||
)
|
||||
return StrictEnumOutputParser(enum=choices)
|
||||
|
||||
def format_prompt(self, prompt, ai_field):
|
||||
output_parser = self.get_output_parser(ai_field)
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
prompt = PromptTemplate(
|
||||
template=prompt + "Given this user query: \n\n{format_instructions}",
|
||||
input_variables=[],
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
message = prompt.format()
|
||||
return message
|
||||
|
||||
def parse_output(self, output, ai_field):
|
||||
if not output:
|
||||
return None
|
||||
|
||||
output_parser = self.get_output_parser(ai_field)
|
||||
try:
|
||||
parsed_output = output_parser.parse(output)
|
||||
except OutputParserException:
|
||||
return None
|
||||
select_option_id = int(parsed_output.name.split("_")[1])
|
||||
try:
|
||||
return next(
|
||||
o for o in ai_field.select_options.all() if o.id == select_option_id
|
||||
)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def prepare_data_sync_value(self, value, field, metadata):
|
||||
try:
|
||||
# The metadata contains a mapping of the select options where the key is the
|
||||
# old ID and the value is the new ID. For some reason the key is converted
|
||||
# to a string when moved into the JSON field.
|
||||
return int(metadata["select_options_mapping"][str(value)])
|
||||
except (KeyError, TypeError):
|
||||
return None
|
|
@ -1,5 +1,4 @@
|
|||
from baserow.contrib.database.fields.field_converters import RecreateFieldConverter
|
||||
from baserow.contrib.database.fields.models import LongTextField, TextField
|
||||
|
||||
from .models import AIField
|
||||
|
||||
|
@ -10,5 +9,5 @@ class AIFieldConverter(RecreateFieldConverter):
|
|||
def is_applicable(self, from_model, from_field, to_field):
|
||||
from_ai = isinstance(from_field, AIField)
|
||||
to_ai = isinstance(to_field, AIField)
|
||||
to_text_fields = isinstance(to_field, (TextField, LongTextField))
|
||||
return from_ai and not (to_text_fields or to_ai) or not from_ai and to_ai
|
||||
# If any field converts to the AI field, then we want to recreate the field
|
||||
return not from_ai and to_ai
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db import IntegrityError, models
|
||||
from django.db.models import Value
|
||||
from django.db import IntegrityError
|
||||
from django.utils.functional import lazy
|
||||
|
||||
from baserow_premium.api.fields.exceptions import (
|
||||
ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD,
|
||||
|
@ -15,15 +15,13 @@ from baserow.api.generative_ai.errors import (
|
|||
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
||||
)
|
||||
from baserow.contrib.database.api.fields.errors import ERROR_FIELD_DOES_NOT_EXIST
|
||||
from baserow.contrib.database.fields.field_filters import (
|
||||
contains_filter,
|
||||
contains_word_filter,
|
||||
from baserow.contrib.database.fields.field_types import (
|
||||
CollationSortMixin,
|
||||
SelectOptionBaseFieldType,
|
||||
)
|
||||
from baserow.contrib.database.fields.field_types import CollationSortMixin, TextField
|
||||
from baserow.contrib.database.fields.models import Field
|
||||
from baserow.contrib.database.fields.registries import FieldType
|
||||
from baserow.contrib.database.formula import BaserowFormulaTextType, BaserowFormulaType
|
||||
from baserow.core.db import collate_expression
|
||||
from baserow.contrib.database.fields.registries import field_type_registry
|
||||
from baserow.contrib.database.formula import BaserowFormulaType
|
||||
from baserow.core.formula.serializers import FormulaSerializerField
|
||||
from baserow.core.generative_ai.exceptions import (
|
||||
GenerativeAITypeDoesNotExist,
|
||||
|
@ -35,6 +33,7 @@ from baserow.core.generative_ai.registries import (
|
|||
)
|
||||
|
||||
from .models import AIField
|
||||
from .registries import ai_field_output_registry
|
||||
from .visitors import replace_field_id_references
|
||||
|
||||
User = get_user_model()
|
||||
|
@ -43,7 +42,7 @@ if TYPE_CHECKING:
|
|||
from baserow.contrib.database.table.models import GeneratedTableModel
|
||||
|
||||
|
||||
class AIFieldType(CollationSortMixin, FieldType):
|
||||
class AIFieldType(CollationSortMixin, SelectOptionBaseFieldType):
|
||||
"""
|
||||
The AI field can automatically query a generative AI model based on the provided
|
||||
prompt. It's possible to reference other fields to generate a unique output.
|
||||
|
@ -53,21 +52,29 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
model_class = AIField
|
||||
can_be_in_form_view = False
|
||||
keep_data_on_duplication = True
|
||||
allowed_fields = [
|
||||
allowed_fields = SelectOptionBaseFieldType.allowed_fields + [
|
||||
"ai_generative_ai_type",
|
||||
"ai_generative_ai_model",
|
||||
"ai_output_type",
|
||||
"ai_temperature",
|
||||
"ai_prompt",
|
||||
"ai_file_field_id",
|
||||
]
|
||||
serializer_field_names = [
|
||||
serializer_field_names = SelectOptionBaseFieldType.allowed_fields + [
|
||||
"ai_generative_ai_type",
|
||||
"ai_generative_ai_model",
|
||||
"ai_output_type",
|
||||
"ai_temperature",
|
||||
"ai_prompt",
|
||||
"ai_file_field_id",
|
||||
]
|
||||
serializer_field_overrides = {
|
||||
"ai_output_type": serializers.ChoiceField(
|
||||
required=False,
|
||||
choices=lazy(ai_field_output_registry.get_types, list)(),
|
||||
help_text="The desired output type of the field. It will try to force the "
|
||||
"response of the prompt to match it.",
|
||||
),
|
||||
"ai_temperature": serializers.FloatField(
|
||||
required=False,
|
||||
allow_null=True,
|
||||
|
@ -89,6 +96,7 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
allow_null=True,
|
||||
default=None,
|
||||
),
|
||||
**SelectOptionBaseFieldType.serializer_field_overrides,
|
||||
}
|
||||
api_exceptions_map = {
|
||||
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
||||
|
@ -96,52 +104,173 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
GenerativeAITypeDoesNotSupportFileField: ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD,
|
||||
IntegrityError: ERROR_FIELD_DOES_NOT_EXIST,
|
||||
}
|
||||
can_get_unique_values = False
|
||||
can_get_unique_values = True
|
||||
can_have_select_options = True
|
||||
|
||||
def get_baserow_field_type(self, instance):
|
||||
output_type = ai_field_output_registry.get(instance.ai_output_type)
|
||||
baserow_field_type = field_type_registry.get_by_type(
|
||||
output_type.baserow_field_type
|
||||
)
|
||||
return baserow_field_type
|
||||
|
||||
def get_serializer_field(self, instance, **kwargs):
|
||||
required = kwargs.get("required", False)
|
||||
return serializers.CharField(
|
||||
**{
|
||||
"required": required,
|
||||
"allow_null": not required,
|
||||
"allow_blank": not required,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
kwargs["read_only"] = True
|
||||
baserow_field_type = self.get_baserow_field_type(instance)
|
||||
return baserow_field_type.get_serializer_field(instance, **kwargs)
|
||||
|
||||
def get_response_serializer_field(self, instance, **kwargs):
|
||||
baserow_field_type = self.get_baserow_field_type(instance)
|
||||
return baserow_field_type.get_response_serializer_field(instance, **kwargs)
|
||||
|
||||
def get_model_field(self, instance, **kwargs):
|
||||
return models.TextField(null=True, **kwargs)
|
||||
baserow_field_type = self.get_baserow_field_type(instance)
|
||||
return baserow_field_type.get_model_field(instance, **kwargs)
|
||||
|
||||
def get_serializer_help_text(self, instance):
|
||||
return (
|
||||
"Holds a text value that is generated by a generative AI model using a "
|
||||
"Holds a value that is generated by a generative AI model using a "
|
||||
"dynamic prompt."
|
||||
)
|
||||
|
||||
def random_value(self, instance, fake, cache):
|
||||
return fake.name()
|
||||
baserow_field_type = self.get_baserow_field_type(instance)
|
||||
return baserow_field_type.random_value(instance, fake, cache)
|
||||
|
||||
def to_baserow_formula_type(self, field) -> BaserowFormulaType:
|
||||
return BaserowFormulaTextType(nullable=True)
|
||||
|
||||
def from_baserow_formula_type(
|
||||
self, formula_type: BaserowFormulaTextType
|
||||
) -> TextField:
|
||||
return TextField()
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.to_baserow_formula_type(field)
|
||||
|
||||
def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> any:
|
||||
value = getattr(row, field.db_column)
|
||||
return collate_expression(Value(value))
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.get_value_for_filter(row, field)
|
||||
|
||||
def contains_query(self, *args):
|
||||
return contains_filter(*args)
|
||||
def get_alter_column_prepare_old_value(self, connection, from_field, to_field):
|
||||
baserow_field_type = self.get_baserow_field_type(from_field)
|
||||
return baserow_field_type.get_alter_column_prepare_old_value(
|
||||
connection, from_field, to_field
|
||||
)
|
||||
|
||||
def contains_word_query(self, *args):
|
||||
return contains_word_filter(*args)
|
||||
def get_alter_column_prepare_new_value(self, connection, from_field, to_field):
|
||||
baserow_field_type = self.get_baserow_field_type(to_field)
|
||||
return baserow_field_type.get_alter_column_prepare_new_value(
|
||||
connection, from_field, to_field
|
||||
)
|
||||
|
||||
def contains_query(self, field_name, value, model_field, field):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.contains_query(field_name, value, model_field, field)
|
||||
|
||||
def contains_word_query(self, field_name, value, model_field, field):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.contains_word_query(
|
||||
field_name, value, model_field, field
|
||||
)
|
||||
|
||||
def check_can_order_by(self, field: Field) -> bool:
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.check_can_order_by(field)
|
||||
|
||||
def check_can_group_by(self, field: Field) -> bool:
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.check_can_group_by(field)
|
||||
|
||||
def get_search_expression(self, field: Field, queryset):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.get_search_expression(field, queryset)
|
||||
|
||||
def is_searchable(self, field):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.is_searchable(field)
|
||||
|
||||
def enhance_queryset(self, queryset, field, name, **kwargs):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.enhance_queryset(queryset, field, name)
|
||||
|
||||
def get_order(self, field, field_name, order_direction):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.get_order(field, field_name, order_direction)
|
||||
|
||||
def serialize_to_input_value(self, field, value):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.serialize_to_input_value(field, value)
|
||||
|
||||
def valid_for_bulk_update(self, field):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.valid_for_bulk_update(field)
|
||||
|
||||
def prepare_value_for_db(self, instance, value):
|
||||
baserow_field_type = self.get_baserow_field_type(instance)
|
||||
return baserow_field_type.prepare_value_for_db(instance, value)
|
||||
|
||||
def prepare_value_for_db_in_bulk(
|
||||
self, instance, values_by_row, continue_on_error=False
|
||||
):
|
||||
baserow_field_type = self.get_baserow_field_type(instance)
|
||||
return baserow_field_type.prepare_value_for_db_in_bulk(
|
||||
instance, values_by_row, continue_on_error
|
||||
)
|
||||
|
||||
def get_group_by_serializer_field(self, field, **kwargs):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.get_group_by_serializer_field(field, **kwargs)
|
||||
|
||||
def get_group_by_field_unique_value(self, field, field_name, value):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.get_group_by_field_unique_value(
|
||||
field, field_name, value
|
||||
)
|
||||
|
||||
def get_group_by_field_filters_and_annotations(
|
||||
self, field, field_name, base_queryset, value
|
||||
):
|
||||
baserow_field_type = self.get_baserow_field_type(field)
|
||||
return baserow_field_type.get_group_by_field_filters_and_annotations(
|
||||
field, field_name, base_queryset, value
|
||||
)
|
||||
|
||||
def get_export_serialized_value(
|
||||
self,
|
||||
row,
|
||||
field_name,
|
||||
cache,
|
||||
files_zip=None,
|
||||
storage=None,
|
||||
):
|
||||
field_object = row.get_field_object(field_name)
|
||||
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||
return baserow_field_type.get_export_serialized_value(
|
||||
row, field_name, cache, files_zip, storage
|
||||
)
|
||||
|
||||
def set_import_serialized_value(
|
||||
self,
|
||||
row,
|
||||
field_name,
|
||||
value,
|
||||
id_mapping,
|
||||
cache,
|
||||
files_zip=None,
|
||||
storage=None,
|
||||
):
|
||||
field_object = row.get_field_object(field_name)
|
||||
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||
return baserow_field_type.set_import_serialized_value(
|
||||
row, field_name, value, id_mapping, cache, files_zip, storage
|
||||
)
|
||||
|
||||
def get_export_value(self, value, field_object, rich_value=False):
|
||||
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||
return baserow_field_type.get_export_value(value, field_object, rich_value)
|
||||
|
||||
def get_human_readable_value(self, value, field_object):
|
||||
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||
return baserow_field_type.get_human_readable_value(value, field_object)
|
||||
|
||||
def _validate_field_kwargs(
|
||||
self, ai_type, model_type, ai_file_field_id, workspace=None
|
||||
self, ai_output_type, ai_type, model_type, ai_file_field_id, workspace=None
|
||||
):
|
||||
ai_field_output_registry.get(ai_output_type)
|
||||
ai_type = generative_ai_model_type_registry.get(ai_type)
|
||||
models = ai_type.get_enabled_models(workspace=workspace)
|
||||
if model_type not in models:
|
||||
|
@ -154,12 +283,19 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
def before_create(
|
||||
self, table, primary, allowed_field_values, order, user, field_kwargs
|
||||
):
|
||||
ai_output_type = field_kwargs.get(
|
||||
"ai_output_type", AIField._meta.get_field("ai_output_type").default
|
||||
)
|
||||
ai_type = field_kwargs.get("ai_generative_ai_type", None)
|
||||
model_type = field_kwargs.get("ai_generative_ai_model", None)
|
||||
ai_file_field_id = field_kwargs.get("ai_file_field_id", None)
|
||||
workspace = table.database.workspace
|
||||
self._validate_field_kwargs(
|
||||
ai_type, model_type, ai_file_field_id, workspace=workspace
|
||||
ai_output_type, ai_type, model_type, ai_file_field_id, workspace=workspace
|
||||
)
|
||||
|
||||
return super().before_create(
|
||||
table, primary, allowed_field_values, order, user, field_kwargs
|
||||
)
|
||||
|
||||
def before_update(self, from_field, to_field_values, user, field_kwargs):
|
||||
|
@ -167,6 +303,11 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
if isinstance(from_field, AIField):
|
||||
update_field = from_field
|
||||
|
||||
ai_output_type = (
|
||||
field_kwargs.get("ai_output_type", None)
|
||||
or getattr(update_field, "ai_output_type", None)
|
||||
or AIField._meta.get_field("ai_output_type").default
|
||||
)
|
||||
ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr(
|
||||
update_field, "ai_generative_ai_type", None
|
||||
)
|
||||
|
@ -179,9 +320,11 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
ai_file_field_id = getattr(update_field, "ai_file_field_id", None)
|
||||
workspace = from_field.table.database.workspace
|
||||
self._validate_field_kwargs(
|
||||
ai_type, model_type, ai_file_field_id, workspace=workspace
|
||||
ai_output_type, ai_type, model_type, ai_file_field_id, workspace=workspace
|
||||
)
|
||||
|
||||
return super().before_update(from_field, to_field_values, user, field_kwargs)
|
||||
|
||||
def after_import_serialized(
|
||||
self,
|
||||
field: AIField,
|
||||
|
@ -209,3 +352,17 @@ class AIFieldType(CollationSortMixin, FieldType):
|
|||
|
||||
if save:
|
||||
field.save()
|
||||
|
||||
def should_backup_field_data_for_same_type_update(
|
||||
self, old_field, new_field_attrs
|
||||
) -> bool:
|
||||
backup = super().should_backup_field_data_for_same_type_update(
|
||||
old_field, new_field_attrs
|
||||
)
|
||||
# Backup the field if the output type changes because
|
||||
ai_output_changed = (
|
||||
"ai_output_type" in new_field_attrs
|
||||
and new_field_attrs["ai_output_type"]
|
||||
and new_field_attrs["ai_output_type"] != old_field.ai_output_type
|
||||
)
|
||||
return backup or ai_output_changed
|
||||
|
|
|
@ -3,12 +3,34 @@ from django.db import models
|
|||
from baserow.contrib.database.fields.models import Field
|
||||
from baserow.core.formula.field import FormulaField as ModelFormulaField
|
||||
|
||||
from .ai_field_output_types import TextAIFieldOutputType
|
||||
from .registries import ai_field_output_registry
|
||||
|
||||
|
||||
class AIField(Field):
|
||||
ai_generative_ai_type = models.CharField(max_length=32, null=True)
|
||||
ai_generative_ai_model = models.CharField(max_length=32, null=True)
|
||||
ai_output_type = models.CharField(
|
||||
max_length=32,
|
||||
db_default=TextAIFieldOutputType.type,
|
||||
default=TextAIFieldOutputType.type,
|
||||
)
|
||||
ai_temperature = models.FloatField(null=True)
|
||||
ai_prompt = ModelFormulaField(default="")
|
||||
ai_file_field = models.ForeignKey(
|
||||
Field, null=True, on_delete=models.SET_NULL, related_name="ai_field"
|
||||
)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
When a property is called on the field object, it tries to return the default
|
||||
value of the field object related to the `ai_output_type` `model_class`. This
|
||||
will make it more compatible with the check functions like `check_can_group_by`.
|
||||
"""
|
||||
|
||||
try:
|
||||
ai_output_type = ai_field_output_registry.get(self.ai_output_type)
|
||||
output_field = ai_output_type.baserow_field_type.model_class
|
||||
return output_field._meta.get_field(name).default
|
||||
except Exception:
|
||||
super().__getattr__(name)
|
||||
|
|
70
premium/backend/src/baserow_premium/fields/registries.py
Normal file
70
premium/backend/src/baserow_premium/fields/registries.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import abc
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
from baserow.contrib.database.fields.models import Field
|
||||
from baserow.core.registry import Instance, Registry
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from baserow_premium.fields.models import AIField
|
||||
|
||||
|
||||
class AIFieldOutputType(abc.ABC, Instance):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def baserow_field_type(self) -> str:
|
||||
"""
|
||||
The Baserow field type that corresponds to this AI output type and should be
|
||||
used to do various Baserow operations like filtering, sorting, etc.
|
||||
"""
|
||||
|
||||
def format_prompt(self, prompt: str, ai_field: "AIField"):
|
||||
"""
|
||||
Hook that can be used to change and format the provided prompt for this output
|
||||
type. It accepts the original already resolved prompt and should return the
|
||||
updated one.
|
||||
|
||||
It can be used to include the format instructions of an output parser, for
|
||||
example.
|
||||
|
||||
:param prompt: The resolved prompt provided by the user. This already contains
|
||||
the resolved variables.
|
||||
:param ai_field: The AI field related to the output type.
|
||||
:return: Should return the formatted prompt. This can include additional
|
||||
information that can change the outcome of the prompt.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def parse_output(self, output: str, ai_field: "AIField"):
|
||||
"""
|
||||
Hook that can be used to parse the output of the generative AI prompt. If an
|
||||
output parser formatting instructions are added in `format_prompt`, then this
|
||||
hook can be used to parse it.
|
||||
|
||||
:param output: The text output of the generative AI.
|
||||
:param ai_field: The AI field related to the output type.
|
||||
:return: Should return the parsed output.
|
||||
"""
|
||||
|
||||
return output
|
||||
|
||||
def prepare_data_sync_value(self, value: Any, field: Field, metadata: dict) -> Any:
|
||||
"""
|
||||
Hook that's called when preparing the value in the local Baserow data sync.
|
||||
It's for example used to map the value of the single select option.
|
||||
|
||||
:param value: The original value.
|
||||
:param field: The field that's synced.
|
||||
:param metadata: The metadata related to the datasync property.
|
||||
:return: The updated value.
|
||||
"""
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class AIFieldOutputRegistry(Registry):
|
||||
name = "ai_field_output"
|
||||
|
||||
|
||||
ai_field_output_registry: AIFieldOutputRegistry = AIFieldOutputRegistry()
|
|
@ -20,6 +20,7 @@ from baserow.core.handler import CoreHandler
|
|||
from baserow.core.user.handler import User
|
||||
|
||||
from .models import AIField
|
||||
from .registries import ai_field_output_registry
|
||||
|
||||
|
||||
@app.task(bind=True, queue="export")
|
||||
|
@ -28,9 +29,9 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
|||
|
||||
ai_field = FieldHandler().get_field(
|
||||
field_id,
|
||||
base_queryset=AIField.objects.all().select_related(
|
||||
"table__database__workspace"
|
||||
),
|
||||
base_queryset=AIField.objects.all()
|
||||
.select_related("table__database__workspace")
|
||||
.prefetch_related("select_options"),
|
||||
)
|
||||
table = ai_field.table
|
||||
workspace = table.database.workspace
|
||||
|
@ -71,6 +72,8 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
|||
)
|
||||
raise exc
|
||||
|
||||
ai_output_type = ai_field_output_registry.get(ai_field.ai_output_type)
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
context = HumanReadableRowContext(row, exclude_field_ids=[ai_field.id])
|
||||
message = str(
|
||||
|
@ -79,6 +82,11 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
|||
)
|
||||
)
|
||||
|
||||
# The AI output type should be able to format the prompt because it can add
|
||||
# additional instructions to it. The choice output type for example adds
|
||||
# additional prompt trying to force the out, for example.
|
||||
message = ai_output_type.format_prompt(message, ai_field)
|
||||
|
||||
try:
|
||||
if ai_field.ai_file_field_id is not None and isinstance(
|
||||
generative_ai_model_type, GenerativeAIWithFilesModelType
|
||||
|
@ -105,6 +113,12 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
|||
workspace=workspace,
|
||||
temperature=ai_field.ai_temperature,
|
||||
)
|
||||
|
||||
# Because the AI output type can change the prompt to try to force the
|
||||
# output a certain way, then it should give the opportunity to parse the
|
||||
# output when it's given. With the choice output type, it will try to match
|
||||
# it to a `SelectOption`, for example.
|
||||
value = ai_output_type.parse_output(value, ai_field)
|
||||
except Exception as exc:
|
||||
# If the prompt fails once, we should not continue with the other rows.
|
||||
rows_ai_values_generation_error.send(
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 5.0.9 on 2024-10-28 13:30
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("baserow_premium", "0022_aifield_ai_temperature"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="aifield",
|
||||
name="ai_output_type",
|
||||
field=models.CharField(db_default="text", default="text", max_length=32),
|
||||
),
|
||||
]
|
|
@ -111,7 +111,8 @@ def test_can_export_every_interesting_different_field_to_json(
|
|||
"uuid": "00000000-0000-4000-8000-000000000001",
|
||||
"autonumber": 1,
|
||||
"password": "",
|
||||
"ai": ""
|
||||
"ai": "",
|
||||
"ai_choice": ""
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
|
@ -230,7 +231,8 @@ def test_can_export_every_interesting_different_field_to_json(
|
|||
"uuid": "00000000-0000-4000-8000-000000000002",
|
||||
"autonumber": 2,
|
||||
"password": true,
|
||||
"ai": "I'm an AI."
|
||||
"ai": "I'm an AI.",
|
||||
"ai_choice": "Object"
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
@ -393,6 +395,7 @@ def test_can_export_every_interesting_different_field_to_xml(
|
|||
<autonumber>1</autonumber>
|
||||
<password/>
|
||||
<ai/>
|
||||
<ai-choice/>
|
||||
</row>
|
||||
<row>
|
||||
<id>2</id>
|
||||
|
@ -512,6 +515,7 @@ def test_can_export_every_interesting_different_field_to_xml(
|
|||
<autonumber>2</autonumber>
|
||||
<password>true</password>
|
||||
<ai>I'm an AI.</ai>
|
||||
<ai-choice>Object</ai-choice>
|
||||
</row>
|
||||
</rows>
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_dynamic_get_attr(premium_data_fixture):
|
||||
field = premium_data_fixture.create_ai_field(ai_output_type="text")
|
||||
assert field.long_text_enable_rich_text is False
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
field.non_existing_property
|
|
@ -0,0 +1,109 @@
|
|||
import enum
|
||||
|
||||
import pytest
|
||||
from baserow_premium.fields.ai_field_output_types import StrictEnumOutputParser
|
||||
from baserow_premium.fields.tasks import generate_ai_values_for_rows
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from baserow.core.generative_ai.registries import (
|
||||
GenerativeAIModelType,
|
||||
generative_ai_model_type_registry,
|
||||
)
|
||||
|
||||
|
||||
def test_strict_enum_output_parser():
|
||||
choices = enum.Enum(
|
||||
"Choices",
|
||||
{
|
||||
"OPTION_1": "Object",
|
||||
"OPTION_2": "Animal",
|
||||
"OPTION_3": "Human",
|
||||
"OPTION_4": "A,B,C",
|
||||
},
|
||||
)
|
||||
output_parser = StrictEnumOutputParser(enum=choices)
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
prompt = "What is a motorcycle?"
|
||||
prompt = PromptTemplate(
|
||||
template=prompt + "\n\n{format_instructions}",
|
||||
input_variables=[],
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
message = prompt.format()
|
||||
|
||||
assert '["Object", "Animal", "Human", "A,B,C"]' in message
|
||||
|
||||
assert output_parser.parse("Object") == choices.OPTION_1
|
||||
assert output_parser.parse("Animal") == choices.OPTION_2
|
||||
assert output_parser.parse("Human") == choices.OPTION_3
|
||||
assert output_parser.parse("A,B,C") == choices.OPTION_4
|
||||
|
||||
assert output_parser.parse(" Object ") == choices.OPTION_1
|
||||
assert output_parser.parse("'Object'") == choices.OPTION_1
|
||||
assert output_parser.parse("'A'") == choices.OPTION_4
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_choice_output_type(premium_data_fixture, api_client):
|
||||
class TestAIChoiceOutputTypeGenerativeAIModelType(GenerativeAIModelType):
|
||||
type = "test_ai_choice_ouput_type"
|
||||
i = 0
|
||||
|
||||
def is_enabled(self, workspace=None):
|
||||
return True
|
||||
|
||||
def get_enabled_models(self, workspace=None):
|
||||
return ["test_1"]
|
||||
|
||||
def prompt(self, model, prompt, workspace=None, temperature=None):
|
||||
self.i += 1
|
||||
if self.i == 1:
|
||||
# Existing option should be matches based on the string.
|
||||
return "Object"
|
||||
else:
|
||||
return "Else"
|
||||
|
||||
def get_settings_serializer(self):
|
||||
return None
|
||||
|
||||
generative_ai_model_type_registry.register(
|
||||
TestAIChoiceOutputTypeGenerativeAIModelType()
|
||||
)
|
||||
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
user = premium_data_fixture.create_user()
|
||||
database = premium_data_fixture.create_database_application(
|
||||
user=user, name="Placeholder"
|
||||
)
|
||||
table = premium_data_fixture.create_database_table(
|
||||
name="Example", database=database
|
||||
)
|
||||
field = premium_data_fixture.create_ai_field(
|
||||
table=table,
|
||||
order=0,
|
||||
name="ai",
|
||||
ai_output_type="choice",
|
||||
ai_generative_ai_type="test_ai_choice_ouput_type",
|
||||
ai_generative_ai_model="test_1",
|
||||
ai_prompt="'Option'",
|
||||
)
|
||||
option_1 = premium_data_fixture.create_select_option(
|
||||
field=field, value="Object", color="red"
|
||||
)
|
||||
option_2 = premium_data_fixture.create_select_option(
|
||||
field=field, value="Else", color="red"
|
||||
)
|
||||
premium_data_fixture.create_select_option(field=field, value="Animal", color="blue")
|
||||
|
||||
model = table.get_model()
|
||||
row_1 = model.objects.create()
|
||||
row_2 = model.objects.create()
|
||||
|
||||
generate_ai_values_for_rows(user.id, field.id, [row_1.id, row_2.id])
|
||||
|
||||
row_1.refresh_from_db()
|
||||
row_2.refresh_from_db()
|
||||
|
||||
assert getattr(row_1, f"field_{field.id}").id == option_1.id
|
||||
assert getattr(row_2, f"field_{field.id}").id == option_2.id
|
|
@ -1,10 +1,12 @@
|
|||
from django.shortcuts import reverse
|
||||
|
||||
import pytest
|
||||
from baserow_premium.fields.field_types import AIFieldType
|
||||
from baserow_premium.fields.models import AIField
|
||||
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||
|
||||
from baserow.contrib.database.fields.handler import FieldHandler
|
||||
from baserow.contrib.database.fields.registries import field_type_registry
|
||||
from baserow.contrib.database.table.handler import TableHandler
|
||||
from baserow.core.db import specific_iterator
|
||||
|
||||
|
@ -28,6 +30,7 @@ def test_create_ai_field_type(premium_data_fixture):
|
|||
ai_prompt="'Who are you?'",
|
||||
)
|
||||
|
||||
assert ai_field.ai_output_type == "text" # default value
|
||||
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
||||
assert ai_field.ai_generative_ai_model == "test_1"
|
||||
assert ai_field.ai_prompt == "'Who are you?'"
|
||||
|
@ -51,6 +54,7 @@ def test_update_ai_field_type(premium_data_fixture):
|
|||
ai_prompt="'Who are you?'",
|
||||
)
|
||||
|
||||
assert ai_field.ai_output_type == "text" # default value
|
||||
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
||||
assert ai_field.ai_generative_ai_model == "test_1"
|
||||
assert ai_field.ai_prompt == "'Who are you?'"
|
||||
|
@ -98,6 +102,68 @@ def test_create_ai_field_type_via_api(premium_data_fixture, api_client):
|
|||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_output_type"] == "text"
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||
assert response_json["ai_temperature"] is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api_with_non_existing_ai_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_output_type": "DOES_NOT_EXIST",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
|
||||
assert response_json["detail"]["ai_output_type"][0]["code"] == "invalid_choice"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api_with_ai_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_output_type": "text",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_output_type"] == "text"
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||
|
@ -191,9 +257,6 @@ def test_update_ai_field_temperature_none_via_api(premium_data_fixture, api_clie
|
|||
field = premium_data_fixture.create_ai_field(
|
||||
table=table, order=1, name="name", ai_temperature=0.7
|
||||
)
|
||||
file_field = premium_data_fixture.create_file_field(
|
||||
table=table, order=2, name="file"
|
||||
)
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
|
@ -209,6 +272,121 @@ def test_update_ai_field_temperature_none_via_api(premium_data_fixture, api_clie
|
|||
assert response.json()["ai_temperature"] is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_ai_field_via_api_invalid_output_type(premium_data_fixture, api_client):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
field = premium_data_fixture.create_ai_field(
|
||||
table=table, order=1, name="name", ai_temperature=0.7
|
||||
)
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{
|
||||
"ai_output_type": "INVALID_CHOICE",
|
||||
"ai_generative_ai_type": "test_generative_ai_with_files",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_temperature": None,
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
|
||||
assert response_json["detail"]["ai_output_type"][0]["code"] == "invalid_choice"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_ai_field_via_api_valid_output_type(premium_data_fixture, api_client):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
field = premium_data_fixture.create_ai_field(
|
||||
table=table, order=1, name="name", ai_temperature=0.7
|
||||
)
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{
|
||||
"ai_output_type": "text",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_temperature": None,
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_output_type"] == "text"
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||
assert response_json["ai_temperature"] is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_to_ai_field_with_all_parameters(premium_data_fixture, api_client):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
field = premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{
|
||||
"type": "ai",
|
||||
"ai_output_type": "text",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_temperature": None,
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_output_type"] == "text"
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||
assert response_json["ai_temperature"] is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_to_ai_field_without_parameters(premium_data_fixture, api_client):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
field = premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{
|
||||
"type": "ai",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_output_type"] == "text"
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == ""
|
||||
assert response_json["ai_temperature"] is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api_invalid_formula(premium_data_fixture, api_client):
|
||||
|
@ -606,3 +784,266 @@ def test_duplicate_table_with_ai_field_broken_references(premium_data_fixture):
|
|||
duplicated_ai_field = duplicated_fields[2]
|
||||
|
||||
assert duplicated_ai_field.ai_prompt == f"concat('test:',get('fields.field_0'))"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_can_set_select_options_to_choice_ai_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
user, token = premium_data_fixture.create_user_and_token()
|
||||
table = premium_data_fixture.create_database_table(user=user)
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_output_type": "choice",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
"select_options": [
|
||||
{"value": "Small", "color": "red"},
|
||||
{"value": "Medium", "color": "blue"},
|
||||
{"value": "Large", "color": "green"},
|
||||
],
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_output_type"] == "choice"
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||
assert len(response_json["select_options"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_should_backup(premium_data_fixture, api_client):
|
||||
ai_field_type = field_type_registry.get(AIFieldType.type)
|
||||
ai_field = premium_data_fixture.create_ai_field()
|
||||
file_field = premium_data_fixture.create_file_field(table=ai_field.table)
|
||||
|
||||
assert (
|
||||
ai_field_type.should_backup_field_data_for_same_type_update(
|
||||
ai_field,
|
||||
{
|
||||
"ai_generative_ai_type": "test_generative_ai_2",
|
||||
"ai_generative_ai_model": "test_model_2",
|
||||
"ai_prompt": "'New AI prompt'",
|
||||
"ai_output_type": "text", # same as before
|
||||
"ai_temperature": 1,
|
||||
"ai_file_field": file_field,
|
||||
},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
assert (
|
||||
ai_field_type.should_backup_field_data_for_same_type_update(
|
||||
ai_field,
|
||||
{
|
||||
"ai_generative_ai_type": "test_generative_ai_2",
|
||||
"ai_generative_ai_model": "test_model_2",
|
||||
"ai_prompt": "'New AI prompt'",
|
||||
"ai_output_type": "choice", # new one
|
||||
"ai_temperature": 1,
|
||||
"ai_file_field": file_field,
|
||||
},
|
||||
)
|
||||
is True
|
||||
) # Expect to make a backup when output type changes.
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_can_convert_from_text_output_type_to_choice_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
user = premium_data_fixture.create_user()
|
||||
database = premium_data_fixture.create_database_application(
|
||||
user=user, name="Placeholder"
|
||||
)
|
||||
table = premium_data_fixture.create_database_table(
|
||||
name="Example", database=database
|
||||
)
|
||||
field = premium_data_fixture.create_ai_field(
|
||||
table=table, order=0, name="ai", ai_output_type="text"
|
||||
)
|
||||
|
||||
model = table.get_model()
|
||||
model.objects.create(**{f"field_{field.id}": "Option 1"})
|
||||
model.objects.create(**{f"field_{field.id}": "Something else"})
|
||||
|
||||
field = FieldHandler().update_field(
|
||||
user=user,
|
||||
field=field,
|
||||
ai_output_type="choice",
|
||||
select_options=[
|
||||
{"value": "Option 1", "color": "red"},
|
||||
],
|
||||
)
|
||||
|
||||
table.refresh_from_db()
|
||||
model = table.get_model()
|
||||
rows = list(model.objects.all())
|
||||
|
||||
# Converting text ai field to choice field should try to convert the text values to
|
||||
# the new choices.
|
||||
assert getattr(rows[0], f"field_{field.id}").value == "Option 1"
|
||||
assert getattr(rows[1], f"field_{field.id}") is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_can_convert_from_choice_output_type_to_text_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
user = premium_data_fixture.create_user()
|
||||
database = premium_data_fixture.create_database_application(
|
||||
user=user, name="Placeholder"
|
||||
)
|
||||
table = premium_data_fixture.create_database_table(
|
||||
name="Example", database=database
|
||||
)
|
||||
field = premium_data_fixture.create_ai_field(
|
||||
table=table, order=0, name="ai", ai_output_type="choice"
|
||||
)
|
||||
select_option = premium_data_fixture.create_select_option(
|
||||
field=field, value="Option 1", color="blue", order=0
|
||||
)
|
||||
|
||||
model = table.get_model()
|
||||
model.objects.create(**{f"field_{field.id}_id": select_option.id})
|
||||
model.objects.create(**{f"field_{field.id}": None})
|
||||
|
||||
field = FieldHandler().update_field(
|
||||
user=user,
|
||||
field=field,
|
||||
ai_output_type="text",
|
||||
)
|
||||
|
||||
table.refresh_from_db()
|
||||
model = table.get_model()
|
||||
rows = list(model.objects.all())
|
||||
|
||||
# Converting choice ai field to text ai field should try to convert the choices to
|
||||
# text values.
|
||||
assert getattr(rows[0], f"field_{field.id}") == "Option 1"
|
||||
assert getattr(rows[1], f"field_{field.id}") is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_can_convert_from_text_field_to_text_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
user = premium_data_fixture.create_user()
|
||||
database = premium_data_fixture.create_database_application(
|
||||
user=user, name="Placeholder"
|
||||
)
|
||||
table = premium_data_fixture.create_database_table(
|
||||
name="Example", database=database
|
||||
)
|
||||
field = premium_data_fixture.create_text_field(table=table, order=0, name="text")
|
||||
|
||||
model = table.get_model()
|
||||
model.objects.create(**{f"field_{field.id}": "Test"})
|
||||
|
||||
field = FieldHandler().update_field(
|
||||
user=user,
|
||||
field=field,
|
||||
new_type_name="ai",
|
||||
ai_generative_ai_type="test_generative_ai",
|
||||
ai_generative_ai_model="test_1",
|
||||
ai_prompt="'test'",
|
||||
ai_output_type="text",
|
||||
)
|
||||
|
||||
table.refresh_from_db()
|
||||
model = table.get_model()
|
||||
rows = list(model.objects.all())
|
||||
|
||||
# Expect the value to be reset because we don't want to keep the existing cell
|
||||
# value when converting from any other field.
|
||||
assert getattr(rows[0], f"field_{field.id}") is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_can_convert_from_text_field_to_choice_output_type(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
user = premium_data_fixture.create_user()
|
||||
database = premium_data_fixture.create_database_application(
|
||||
user=user, name="Placeholder"
|
||||
)
|
||||
table = premium_data_fixture.create_database_table(
|
||||
name="Example", database=database
|
||||
)
|
||||
field = premium_data_fixture.create_text_field(table=table, order=0, name="text")
|
||||
|
||||
model = table.get_model()
|
||||
model.objects.create(**{f"field_{field.id}": "Test"})
|
||||
|
||||
field = FieldHandler().update_field(
|
||||
user=user,
|
||||
field=field,
|
||||
new_type_name="ai",
|
||||
ai_generative_ai_type="test_generative_ai",
|
||||
ai_generative_ai_model="test_1",
|
||||
ai_prompt="'test'",
|
||||
ai_output_type="choice",
|
||||
)
|
||||
|
||||
table.refresh_from_db()
|
||||
model = table.get_model()
|
||||
rows = list(model.objects.all())
|
||||
|
||||
# Expect the value to be reset because we don't want to keep the existing cell
|
||||
# value when converting from any other field.
|
||||
assert getattr(rows[0], f"field_{field.id}") is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_can_convert_from_text_output_type_to_text_field(
|
||||
premium_data_fixture, api_client
|
||||
):
|
||||
premium_data_fixture.register_fake_generate_ai_type()
|
||||
user = premium_data_fixture.create_user()
|
||||
database = premium_data_fixture.create_database_application(
|
||||
user=user, name="Placeholder"
|
||||
)
|
||||
table = premium_data_fixture.create_database_table(
|
||||
name="Example", database=database
|
||||
)
|
||||
field = premium_data_fixture.create_ai_field(table=table, order=0, name="ai")
|
||||
|
||||
model = table.get_model()
|
||||
model.objects.create(**{f"field_{field.id}": "Test"})
|
||||
|
||||
field = FieldHandler().update_field(
|
||||
user=user,
|
||||
field=field,
|
||||
new_type_name="text",
|
||||
)
|
||||
|
||||
table.refresh_from_db()
|
||||
model = table.get_model()
|
||||
rows = list(model.objects.all())
|
||||
|
||||
# Converting text ai field to text field should keep the values because the text
|
||||
# field conversion is automatically used.
|
||||
assert getattr(rows[0], f"field_{field.id}") == "Test"
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
import { Registerable } from '@baserow/modules/core/registry'
|
||||
|
||||
import {
|
||||
LongTextFieldType,
|
||||
SingleSelectFieldType,
|
||||
} from '@baserow/modules/database/fieldTypes'
|
||||
import FieldSelectOptionsSubForm from '@baserow/modules/database/components/field/FieldSelectOptionsSubForm.vue'
|
||||
|
||||
export class AIFieldOutputType extends Registerable {
|
||||
/**
|
||||
* A human readable name of the AI output type. This will be shown in in the dropdown
|
||||
* where the user chosen the type.
|
||||
*/
|
||||
getName() {
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* A human-readable description of the AI output type.
|
||||
*/
|
||||
getDescription() {
|
||||
return null
|
||||
}
|
||||
|
||||
constructor(...args) {
|
||||
super(...args)
|
||||
this.type = this.getType()
|
||||
|
||||
if (this.type === null) {
|
||||
throw new Error('The type name of an admin type must be set.')
|
||||
}
|
||||
if (this.name === null) {
|
||||
throw new Error('The name of an admin type must be set.')
|
||||
}
|
||||
}
|
||||
|
||||
getBaserowFieldType() {
|
||||
throw new Error('The Baserow field type must be set on the AI output type.')
|
||||
}
|
||||
|
||||
/**
|
||||
* Can optionally return a form component that will be added to the `FieldForm` is
|
||||
* the output type is chosen.
|
||||
*/
|
||||
getFormComponent() {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export class TextAIFieldOutputType extends AIFieldOutputType {
|
||||
static getType() {
|
||||
return 'text'
|
||||
}
|
||||
|
||||
getName() {
|
||||
const { i18n } = this.app
|
||||
return i18n.t('aiOutputType.text')
|
||||
}
|
||||
|
||||
getDescription() {
|
||||
const { i18n } = this.app
|
||||
return i18n.t('aiOutputType.textDescription')
|
||||
}
|
||||
|
||||
getBaserowFieldType() {
|
||||
return this.app.$registry.get('field', LongTextFieldType.getType())
|
||||
}
|
||||
}
|
||||
|
||||
export class ChoiceAIFieldOutputType extends AIFieldOutputType {
|
||||
static getType() {
|
||||
return 'choice'
|
||||
}
|
||||
|
||||
getName() {
|
||||
const { i18n } = this.app
|
||||
return i18n.t('aiOutputType.choice')
|
||||
}
|
||||
|
||||
getDescription() {
|
||||
const { i18n } = this.app
|
||||
return i18n.t('aiOutputType.choiceDescription')
|
||||
}
|
||||
|
||||
getBaserowFieldType() {
|
||||
return this.app.$registry.get('field', SingleSelectFieldType.getType())
|
||||
}
|
||||
|
||||
getFormComponent() {
|
||||
return FieldSelectOptionsSubForm
|
||||
}
|
||||
}
|
|
@ -35,6 +35,31 @@
|
|||
/>
|
||||
</Dropdown>
|
||||
</FormGroup>
|
||||
|
||||
<FormGroup
|
||||
required
|
||||
small-label
|
||||
:label="$t('fieldAISubForm.outputType')"
|
||||
:help-icon-tooltip="$t('fieldAISubForm.outputTypeTooltip')"
|
||||
>
|
||||
<Dropdown
|
||||
v-model="values.ai_output_type"
|
||||
class="dropdown--floating"
|
||||
:fixed-items="true"
|
||||
>
|
||||
<DropdownItem
|
||||
v-for="outputType in outputTypes"
|
||||
:key="outputType.getType()"
|
||||
:name="outputType.getName()"
|
||||
:value="outputType.getType()"
|
||||
:description="outputType.getDescription()"
|
||||
/>
|
||||
</Dropdown>
|
||||
<template v-if="changedOutputType" #warning>
|
||||
{{ $t('fieldAISubForm.outputTypeChangedWarning') }}
|
||||
</template>
|
||||
</FormGroup>
|
||||
|
||||
<FormGroup
|
||||
small-label
|
||||
:label="$t('fieldAISubForm.prompt')"
|
||||
|
@ -52,6 +77,12 @@
|
|||
</div>
|
||||
<template #error> {{ $t('error.requiredField') }}</template>
|
||||
</FormGroup>
|
||||
|
||||
<component
|
||||
:is="outputType.getFormComponent()"
|
||||
ref="childForm"
|
||||
v-bind="$props"
|
||||
/>
|
||||
</div>
|
||||
<div v-else>
|
||||
<p>
|
||||
|
@ -67,6 +98,7 @@ import form from '@baserow/modules/core/mixins/form'
|
|||
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
|
||||
import FormulaInputField from '@baserow/modules/core/components/formula/FormulaInputField'
|
||||
import SelectAIModelForm from '@baserow/modules/core/components/ai/SelectAIModelForm'
|
||||
import { TextAIFieldOutputType } from '@baserow_premium/aiFieldOutputTypes'
|
||||
|
||||
export default {
|
||||
name: 'FieldAISubForm',
|
||||
|
@ -74,9 +106,10 @@ export default {
|
|||
mixins: [form, fieldSubForm],
|
||||
data() {
|
||||
return {
|
||||
allowedValues: ['ai_prompt', 'ai_file_field_id'],
|
||||
allowedValues: ['ai_prompt', 'ai_file_field_id', 'ai_output_type'],
|
||||
values: {
|
||||
ai_prompt: '',
|
||||
ai_output_type: TextAIFieldOutputType.getType(),
|
||||
ai_file_field_id: null,
|
||||
},
|
||||
fileFieldSupported: false,
|
||||
|
@ -113,6 +146,19 @@ export default {
|
|||
return t.canRepresentFiles(field)
|
||||
})
|
||||
},
|
||||
outputTypes() {
|
||||
return Object.values(this.$registry.getAll('aiFieldOutputType'))
|
||||
},
|
||||
outputType() {
|
||||
return this.$registry.get('aiFieldOutputType', this.values.ai_output_type)
|
||||
},
|
||||
changedOutputType() {
|
||||
return (
|
||||
this.defaultValues.id &&
|
||||
this.defaultValues.type === this.values.type &&
|
||||
this.defaultValues.ai_output_type !== this.values.ai_output_type
|
||||
)
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
setFileFieldSupported(generativeAIType) {
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
<template>
|
||||
<div class="control__elements">
|
||||
<FormTextarea
|
||||
ref="input"
|
||||
v-model="value"
|
||||
type="text"
|
||||
class="margin-bottom-2"
|
||||
:rows="6"
|
||||
:disabled="true"
|
||||
/>
|
||||
|
||||
<template v-if="!readOnly">
|
||||
<div>
|
||||
<component
|
||||
:is="outputRowEditFieldComponent"
|
||||
ref="field"
|
||||
v-bind="$props"
|
||||
:read-only="true"
|
||||
></component>
|
||||
<div v-if="!readOnly" class="margin-top-2">
|
||||
<Button
|
||||
v-if="isDeactivated && rowIsCreated"
|
||||
type="secondary"
|
||||
|
@ -19,7 +16,7 @@
|
|||
{{ $t('rowEditFieldAI.generate') }}
|
||||
</Button>
|
||||
<Button
|
||||
v-if="rowIsCreated"
|
||||
v-else-if="rowIsCreated"
|
||||
type="secondary"
|
||||
:loading="generating"
|
||||
@click="generate()"
|
||||
|
@ -33,7 +30,7 @@
|
|||
:workspace="workspace"
|
||||
:name="fieldName"
|
||||
></component>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
@ -48,6 +45,12 @@ export default {
|
|||
fieldName() {
|
||||
return this.$registry.get('field', this.field.type).getName()
|
||||
},
|
||||
outputRowEditFieldComponent() {
|
||||
return this.$registry
|
||||
.get('aiFieldOutputType', this.field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getRowEditFieldComponent(this.field)
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
|
|
@ -23,9 +23,15 @@
|
|||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="grid-view__cell grid-field-long-text__cell">
|
||||
<div class="grid-field-long-text">{{ props.value }}</div>
|
||||
</div>
|
||||
<component
|
||||
:is="$options.methods.getFunctionalOutputFieldComponent(parent, props)"
|
||||
v-else
|
||||
:workspace-id="props.workspaceId"
|
||||
:field="props.field"
|
||||
:value="props.value"
|
||||
:state="props.state"
|
||||
:read-only="props.readOnly"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
|
@ -41,6 +47,12 @@ export default {
|
|||
.get('field', AIFieldType.getType())
|
||||
.isDeactivated(props.workspaceId)
|
||||
},
|
||||
getFunctionalOutputFieldComponent(parent, props) {
|
||||
return parent.$registry
|
||||
.get('aiFieldOutputType', props.field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getFunctionalGridViewFieldComponent(props.field)
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
|
|
@ -18,40 +18,34 @@
|
|||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
<component
|
||||
:is="outputGridViewFieldComponent"
|
||||
v-else
|
||||
ref="cell"
|
||||
class="grid-view__cell grid-field-long-text__cell active"
|
||||
:class="{ editing: opened }"
|
||||
@keyup.enter="opened = true"
|
||||
v-bind="$props"
|
||||
:read-only="true"
|
||||
>
|
||||
<div v-if="!opened" class="grid-field-long-text">{{ value }}</div>
|
||||
<template v-else>
|
||||
<div class="grid-field-long-text__textarea">
|
||||
{{ value }}
|
||||
</div>
|
||||
<template v-if="!readOnly && editing" #default="{ editing }">
|
||||
<div style="background-color: #fff; padding: 8px">
|
||||
<template v-if="!readOnly">
|
||||
<ButtonText
|
||||
v-if="!isDeactivated"
|
||||
icon="iconoir-magic-wand"
|
||||
:disabled="!modelAvailable || generating"
|
||||
:loading="generating"
|
||||
@click.prevent.stop="generate()"
|
||||
>
|
||||
{{ $t('gridViewFieldAI.regenerate') }}
|
||||
</ButtonText>
|
||||
<ButtonText
|
||||
v-else
|
||||
icon="iconoir-lock"
|
||||
@click.prevent.stop="$refs.clickModal.show()"
|
||||
>
|
||||
{{ $t('gridViewFieldAI.regenerate') }}
|
||||
</ButtonText>
|
||||
</template>
|
||||
<ButtonText
|
||||
v-if="!isDeactivated"
|
||||
icon="iconoir-magic-wand"
|
||||
:disabled="!modelAvailable || generating"
|
||||
:loading="generating"
|
||||
@click.prevent.stop="generate()"
|
||||
>
|
||||
{{ $t('gridViewFieldAI.regenerate') }}
|
||||
</ButtonText>
|
||||
<ButtonText
|
||||
v-else
|
||||
icon="iconoir-lock"
|
||||
@click.prevent.stop="$refs.clickModal.show()"
|
||||
>
|
||||
{{ $t('gridViewFieldAI.regenerate') }}
|
||||
</ButtonText>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</component>
|
||||
<component
|
||||
:is="deactivatedClickComponent"
|
||||
v-if="isDeactivated && workspace"
|
||||
|
@ -75,6 +69,12 @@ export default {
|
|||
fieldName() {
|
||||
return this.$registry.get('field', this.field.type).getName()
|
||||
},
|
||||
outputGridViewFieldComponent() {
|
||||
return this.$registry
|
||||
.get('aiFieldOutputType', this.field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getGridViewFieldComponent(this.field)
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
save() {
|
||||
|
|
|
@ -2,13 +2,6 @@ import {
|
|||
FieldType,
|
||||
FormulaFieldType,
|
||||
} from '@baserow/modules/database/fieldTypes'
|
||||
import RowHistoryFieldText from '@baserow/modules/database/components/row/RowHistoryFieldText'
|
||||
import RowCardFieldText from '@baserow/modules/database/components/card/RowCardFieldText'
|
||||
import { collatedStringCompare } from '@baserow/modules/core/utils/string'
|
||||
import {
|
||||
genericContainsFilter,
|
||||
genericContainsWordFilter,
|
||||
} from '@baserow/modules/database/utils/fieldFilters'
|
||||
|
||||
import GridViewFieldAI from '@baserow_premium/components/views/grid/fields/GridViewFieldAI'
|
||||
import FunctionalGridViewFieldAI from '@baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI'
|
||||
|
@ -33,6 +26,10 @@ export class AIFieldType extends FieldType {
|
|||
return i18n.t('premiumFieldType.ai')
|
||||
}
|
||||
|
||||
getIsReadOnly() {
|
||||
return true
|
||||
}
|
||||
|
||||
getGridViewFieldComponent() {
|
||||
return GridViewFieldAI
|
||||
}
|
||||
|
@ -45,18 +42,21 @@ export class AIFieldType extends FieldType {
|
|||
return RowEditFieldAI
|
||||
}
|
||||
|
||||
getCardComponent() {
|
||||
return RowCardFieldText
|
||||
}
|
||||
|
||||
getRowHistoryEntryComponent() {
|
||||
return RowHistoryFieldText
|
||||
}
|
||||
|
||||
getFormComponent() {
|
||||
return FieldAISubForm
|
||||
}
|
||||
|
||||
getCardComponent(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getCardComponent(field)
|
||||
}
|
||||
|
||||
getRowHistoryEntryComponent(field) {
|
||||
return null
|
||||
}
|
||||
|
||||
getFormViewFieldComponents(field) {
|
||||
return {}
|
||||
}
|
||||
|
@ -65,17 +65,32 @@ export class AIFieldType extends FieldType {
|
|||
return null
|
||||
}
|
||||
|
||||
getSort(name, order) {
|
||||
return (a, b) => {
|
||||
const stringA = a[name] === null ? '' : '' + a[name]
|
||||
const stringB = b[name] === null ? '' : '' + b[name]
|
||||
getCardValueHeight(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getCardValueHeight(field)
|
||||
}
|
||||
|
||||
return collatedStringCompare(stringA, stringB, order)
|
||||
}
|
||||
getSort(name, order, field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getSort(name, order, field)
|
||||
}
|
||||
|
||||
getCanSortInView(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getCanSortInView(field)
|
||||
}
|
||||
|
||||
getDocsDataType(field) {
|
||||
return 'string'
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getDocsDataType(field)
|
||||
}
|
||||
|
||||
getDocsDescription(field) {
|
||||
|
@ -84,15 +99,147 @@ export class AIFieldType extends FieldType {
|
|||
}
|
||||
|
||||
getDocsRequestExample(field) {
|
||||
return 'string'
|
||||
return 'read only'
|
||||
}
|
||||
|
||||
getContainsFilterFunction() {
|
||||
return genericContainsFilter
|
||||
getDocsResponseExample(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getDocsResponseExample(field)
|
||||
}
|
||||
|
||||
prepareValueForCopy(field, value) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.prepareValueForCopy(field, value)
|
||||
}
|
||||
|
||||
getContainsFilterFunction(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getContainsFilterFunction(field)
|
||||
}
|
||||
|
||||
getContainsWordFilterFunction(field) {
|
||||
return genericContainsWordFilter
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getContainsWordFilterFunction(field)
|
||||
}
|
||||
|
||||
toHumanReadableString(field, value) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.toHumanReadableString(field, value)
|
||||
}
|
||||
|
||||
getSortIndicator(field, registry) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getSortIndicator(field, registry)
|
||||
}
|
||||
|
||||
canRepresentDate(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.canRepresentDate(field)
|
||||
}
|
||||
|
||||
getCanGroupByInView(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getCanGroupByInView(field)
|
||||
}
|
||||
|
||||
parseInputValue(field, value) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.parseInputValue(field, value)
|
||||
}
|
||||
|
||||
canRepresentFiles(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.canRepresentFiles(field)
|
||||
}
|
||||
|
||||
getHasEmptyValueFilterFunction(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getHasEmptyValueFilterFunction(field)
|
||||
}
|
||||
|
||||
getHasValueContainsFilterFunction(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getHasValueContainsFilterFunction(field)
|
||||
}
|
||||
|
||||
getHasValueContainsWordFilterFunction(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getHasValueContainsWordFilterFunction(field)
|
||||
}
|
||||
|
||||
getHasValueLengthIsLowerThanFilterFunction(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getHasValueLengthIsLowerThanFilterFunction(field)
|
||||
}
|
||||
|
||||
getGroupByComponent(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getGroupByComponent(field)
|
||||
}
|
||||
|
||||
getGroupByIndicator(field, registry) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getGroupByIndicator(field, registry)
|
||||
}
|
||||
|
||||
getRowValueFromGroupValue(field, value) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getRowValueFromGroupValue(field, value)
|
||||
}
|
||||
|
||||
getGroupValueFromRowValue(field, value) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.getGroupValueFromRowValue(field, value)
|
||||
}
|
||||
|
||||
isEqual(field, value1, value2) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.isEqual(field, value1, value2)
|
||||
}
|
||||
|
||||
canBeReferencedByFormulaField(field) {
|
||||
return this.app.$registry
|
||||
.get('aiFieldOutputType', field.ai_output_type)
|
||||
.getBaserowFieldType()
|
||||
.canBeReferencedByFormulaField(field)
|
||||
}
|
||||
|
||||
getGridViewContextItemsOnCellsSelection(field) {
|
||||
|
|
|
@ -409,7 +409,10 @@
|
|||
"promptPlaceholder": "What is Baserow?",
|
||||
"premiumFeature": "The AI field is a premium feature",
|
||||
"emptyFileField": "None",
|
||||
"fileFieldHelp": "The first compatible file in the field will be used as the knowledge base for the prompt. The file has to be a text file with the supported file extension like .txt, .md, .pdf, .docx."
|
||||
"fileFieldHelp": "The first compatible file in the field will be used as the knowledge base for the prompt. The file has to be a text file with the supported file extension like .txt, .md, .pdf, .docx.",
|
||||
"outputType": "Output type",
|
||||
"outputTypeTooltip": "Select your desired output format to guide the LLM in generating responses that match your desired options.",
|
||||
"outputTypeChangedWarning": "Clears the generated cell values."
|
||||
},
|
||||
"rowEditFieldAI": {
|
||||
"generate": "Generate",
|
||||
|
@ -426,5 +429,11 @@
|
|||
"formulaFieldAI": {
|
||||
"generateWithAI": "Generate using AI",
|
||||
"featureName": "Generate formula using AI"
|
||||
},
|
||||
"aiOutputType": {
|
||||
"text": "Text",
|
||||
"textDescription": "Generates free text based on the prompt.",
|
||||
"choice": "Choice",
|
||||
"choiceDescription": "Chooses only one of the field options."
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,6 +52,10 @@ import {
|
|||
AIFieldType,
|
||||
PremiumFormulaFieldType,
|
||||
} from '@baserow_premium/fieldTypes'
|
||||
import {
|
||||
ChoiceAIFieldOutputType,
|
||||
TextAIFieldOutputType,
|
||||
} from '@baserow_premium/aiFieldOutputTypes'
|
||||
|
||||
export default (context) => {
|
||||
const { store, app, isDev } = context
|
||||
|
@ -94,6 +98,8 @@ export default (context) => {
|
|||
store.registerModule('template/view/timeline', timelineStore)
|
||||
store.registerModule('impersonating', impersonatingStore)
|
||||
|
||||
app.$registry.registerNamespace('aiFieldOutputType')
|
||||
|
||||
app.$registry.register('plugin', new PremiumPlugin(context))
|
||||
app.$registry.register('admin', new DashboardType(context))
|
||||
app.$registry.register('admin', new UsersAdminType(context))
|
||||
|
@ -160,4 +166,13 @@ export default (context) => {
|
|||
'rowModalSidebar',
|
||||
new CommentsRowModalSidebarType(context)
|
||||
)
|
||||
|
||||
app.$registry.register(
|
||||
'aiFieldOutputType',
|
||||
new TextAIFieldOutputType(context)
|
||||
)
|
||||
app.$registry.register(
|
||||
'aiFieldOutputType',
|
||||
new ChoiceAIFieldOutputType(context)
|
||||
)
|
||||
}
|
||||
|
|
|
@ -180,6 +180,7 @@ export class KanbanViewType extends PremiumViewType {
|
|||
row,
|
||||
values,
|
||||
metadata,
|
||||
updatedFieldIds,
|
||||
storePrefix = ''
|
||||
) {
|
||||
if (this.isCurrentView(store, tableId)) {
|
||||
|
@ -386,6 +387,7 @@ export class CalendarViewType extends PremiumViewType {
|
|||
row,
|
||||
values,
|
||||
metadata,
|
||||
updatedFieldIds,
|
||||
storePrefix = ''
|
||||
) {
|
||||
if (this.isCurrentView(store, tableId)) {
|
||||
|
|
|
@ -1,3 +1,13 @@
|
|||
.grid-field-single-select__cell {
|
||||
&.active {
|
||||
bottom: auto;
|
||||
right: auto;
|
||||
height: auto;
|
||||
min-height: calc(100% + 4px);
|
||||
min-width: calc(100% + 4px);
|
||||
}
|
||||
}
|
||||
|
||||
.grid-field-single-select {
|
||||
position: relative;
|
||||
display: block;
|
||||
|
|
|
@ -84,7 +84,7 @@ export default {
|
|||
const isNotThisField = f.id !== this.defaultValues.id
|
||||
const canBeReferencedByFormulaField = this.$registry
|
||||
.get('field', f.type)
|
||||
.canBeReferencedByFormulaField()
|
||||
.canBeReferencedByFormulaField(f)
|
||||
return isNotThisField && canBeReferencedByFormulaField
|
||||
})
|
||||
},
|
||||
|
|
|
@ -114,7 +114,7 @@ export default {
|
|||
.filter((f) => {
|
||||
return this.$registry
|
||||
.get('field', f.type)
|
||||
.canBeReferencedByFormulaField()
|
||||
.canBeReferencedByFormulaField(f)
|
||||
})
|
||||
.filter((f) => {
|
||||
return this.$hasPermission(
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
<template functional>
|
||||
<div ref="cell" class="grid-view__cell" :class="data.staticClass || ''">
|
||||
<div
|
||||
ref="cell"
|
||||
class="grid-view__cell grid-field-single-select__cell"
|
||||
:class="data.staticClass || ''"
|
||||
>
|
||||
<div class="grid-field-single-select">
|
||||
<div
|
||||
v-if="props.value"
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
class="grid-field-long-text__textarea"
|
||||
/>
|
||||
<div v-else class="grid-field-long-text__textarea">{{ value }}</div>
|
||||
<slot name="default" :slot-props="{ editing, opened }"></slot>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
<template>
|
||||
<div ref="cell" class="grid-view__cell active">
|
||||
<div ref="cell" class="grid-view__cell grid-field-single-select__cell active">
|
||||
<div
|
||||
ref="dropdownLink"
|
||||
class="grid-field-single-select grid-field-single-select--selected"
|
||||
|
@ -18,6 +18,7 @@
|
|||
class="iconoir-nav-arrow-down grid-field-single-select__icon"
|
||||
></i>
|
||||
</div>
|
||||
<slot name="default" :slot-props="{ editing, opened: true }"></slot>
|
||||
<FieldSelectOptionsDropdown
|
||||
v-if="!readOnly"
|
||||
ref="dropdown"
|
||||
|
|
|
@ -727,7 +727,7 @@ export class FieldType extends Registerable {
|
|||
* Override and return true if the field type can be referenced by a formula field.
|
||||
* @return {boolean}
|
||||
*/
|
||||
canBeReferencedByFormulaField() {
|
||||
canBeReferencedByFormulaField(field) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -202,6 +202,7 @@ export const registerRealtimeEvents = (realtime) => {
|
|||
rowBeforeUpdate,
|
||||
row,
|
||||
data.metadata[row.id],
|
||||
data.updated_field_ids,
|
||||
'page/'
|
||||
)
|
||||
}
|
||||
|
|
|
@ -2506,7 +2506,7 @@ export const actions = {
|
|||
*/
|
||||
async updatedExistingRow(
|
||||
{ commit, getters, dispatch },
|
||||
{ view, fields, row, values, metadata }
|
||||
{ view, fields, row, values, metadata, updatedFieldIds = [] }
|
||||
) {
|
||||
const oldRow = clone(row)
|
||||
const newRow = Object.assign(clone(row), values)
|
||||
|
@ -2646,15 +2646,21 @@ export const actions = {
|
|||
// sure the loading state will stop if the value is updated. This is done even
|
||||
// if the row is not found in the buffer because it could have been removed from
|
||||
// the buffer when scrolling outside the buffer range.
|
||||
const updatedFieldIds = Object.entries(values)
|
||||
const getFieldId = (key) => parseInt(key.split('_')[1])
|
||||
const fieldIdsToClearPendingOperationsFor = Object.entries(values)
|
||||
.filter(
|
||||
([key, value]) =>
|
||||
key.startsWith('field_') && !_.isEqual(value, oldRow[key])
|
||||
key.startsWith('field_') &&
|
||||
// Either the value has changed.
|
||||
(_.isEqual(value, oldRow[key]) ||
|
||||
// Or the backend has just recalculated the value, even if it hasn't
|
||||
// actually changed.
|
||||
updatedFieldIds.includes(getFieldId(key)))
|
||||
)
|
||||
.map(([key, value]) => parseInt(key.split('_')[1]))
|
||||
.map(([key, value]) => getFieldId(key))
|
||||
|
||||
commit('CLEAR_PENDING_FIELD_OPERATIONS', {
|
||||
fieldIds: updatedFieldIds,
|
||||
fieldIds: fieldIdsToClearPendingOperationsFor,
|
||||
rowId: row.id,
|
||||
})
|
||||
|
||||
|
|
|
@ -267,7 +267,16 @@ export class ViewType extends Registerable {
|
|||
* via a real time event by another user. It can be used to check if data in an store
|
||||
* needs to be updated.
|
||||
*/
|
||||
rowUpdated(context, tableId, fields, row, values, metadata, storePrefix) {}
|
||||
rowUpdated(
|
||||
context,
|
||||
tableId,
|
||||
fields,
|
||||
row,
|
||||
values,
|
||||
metadata,
|
||||
updatedFieldIds,
|
||||
storePrefix
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Event that is called when something went wrong while generating AI values
|
||||
|
@ -722,6 +731,7 @@ export class GridViewType extends ViewType {
|
|||
row,
|
||||
values,
|
||||
metadata,
|
||||
updatedFieldIds,
|
||||
storePrefix = ''
|
||||
) {
|
||||
if (this.isCurrentView(store, tableId)) {
|
||||
|
@ -731,6 +741,7 @@ export class GridViewType extends ViewType {
|
|||
row,
|
||||
values,
|
||||
metadata,
|
||||
updatedFieldIds,
|
||||
})
|
||||
await store.dispatch(storePrefix + 'view/grid/fetchByScrollTopDelayed', {
|
||||
scrollTop: store.getters[storePrefix + 'view/grid/getScrollTop'],
|
||||
|
@ -971,6 +982,7 @@ export const BaseBufferedRowViewTypeMixin = (Base) =>
|
|||
row,
|
||||
values,
|
||||
metadata,
|
||||
updatedFieldIds,
|
||||
storePrefix = ''
|
||||
) {
|
||||
if (this.isCurrentView(store, tableId)) {
|
||||
|
|
Loading…
Add table
Reference in a new issue