mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-17 18:32:35 +00:00
AI field
This commit is contained in:
parent
7021acdb64
commit
1298aa7eec
98 changed files with 2393 additions and 108 deletions
backend
Makefilepytest.ini
requirements
src/baserow
api
config/settings
contrib/database
core
test_utils
tests/baserow/contrib
database
api
export
field
management
rows
integrations/local_baserow
changelog/entries/unreleased/feature
docker-compose.ymlpremium
backend/tests/baserow_premium_tests/export
web-frontend/modules/baserow_premium/components/views
web-frontend
locales
modules
core
database
components
field
ChooseSingleSelectField.vueCreateFieldContext.vueDuplicateFieldModal.vueFieldAISubForm.vueFieldContext.vueFieldCountSubForm.vueFieldForm.vueFieldFormulaSubForm.vueFieldLookupSubForm.vueFieldRollupSubForm.vueInsertFieldContext.vueUpdateFieldContext.vue
formula
row
ForeignRowEditModal.vueRowCreateModal.vueRowEditFieldAI.vueRowEditModal.vueRowEditModalField.vueRowEditModalFieldsList.vueSelectRowContent.vue
view
locales
mixins
plugin.jsrealtime.jsservices
store/view
viewTypes.jstest/unit/database
|
@ -21,7 +21,7 @@ format:
|
|||
fix: sort format
|
||||
|
||||
sort:
|
||||
isort --skip generated src tests ../premium/backend ../enterprise/backend || exit;
|
||||
isort --skip generated --overwrite-in-place src tests ../premium/backend ../enterprise/backend || exit;
|
||||
|
||||
test:
|
||||
pytest tests ../premium/backend/tests ../enterprise/backend/tests || exit;
|
||||
|
|
|
@ -37,6 +37,7 @@ markers =
|
|||
field_multiple_collaborators: All tests related to multiple collaborator field
|
||||
field_last_modified_by: All tests related to last modified by field
|
||||
field_autonumber: All tests related to autonumber field
|
||||
field_ai: All tests related to AI field
|
||||
view_ownership: All tests related to view ownership type
|
||||
view_calendar: All tests related to the calendar view
|
||||
api_rows: All tests to manipulate rows via HTTP API
|
||||
|
|
|
@ -66,3 +66,6 @@ https://github.com/fellowapp/prosemirror-py/archive/refs/tags/v0.3.5.zip
|
|||
rich==13.7.0
|
||||
tzdata==2023.3
|
||||
sentry-sdk==1.39.1
|
||||
openai==1.9.0
|
||||
typing_extensions==4.7.1
|
||||
ollama==0.1.5
|
||||
|
|
|
@ -8,10 +8,15 @@ advocate==1.0.0
|
|||
# via -r base.in
|
||||
amqp==5.1.1
|
||||
# via kombu
|
||||
annotated-types==0.6.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via -r base.in
|
||||
anyio==3.6.1
|
||||
# via watchfiles
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
# watchfiles
|
||||
asgiref==3.6.0
|
||||
# via
|
||||
# -r base.in
|
||||
|
@ -66,6 +71,8 @@ celery-singleton==0.3.1
|
|||
# via -r base.in
|
||||
certifi==2023.7.22
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.15.1
|
||||
|
@ -110,6 +117,8 @@ deprecated==1.2.13
|
|||
# via
|
||||
# opentelemetry-api
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
dj-database-url==1.3.0
|
||||
# via -r base.in
|
||||
django==4.1.13
|
||||
|
@ -183,9 +192,17 @@ googleapis-common-protos==1.58.0
|
|||
gunicorn==20.1.0
|
||||
# via -r base.in
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.2
|
||||
# via httpx
|
||||
httptools==0.5.0
|
||||
# via uvicorn
|
||||
httpx==0.25.2
|
||||
# via
|
||||
# ollama
|
||||
# openai
|
||||
hyperlink==21.0.0
|
||||
# via
|
||||
# autobahn
|
||||
|
@ -193,6 +210,7 @@ hyperlink==21.0.0
|
|||
idna==3.4
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# hyperlink
|
||||
# requests
|
||||
# twisted
|
||||
|
@ -230,6 +248,10 @@ netifaces==0.11.0
|
|||
# via advocate
|
||||
oauthlib==3.2.2
|
||||
# via requests-oauthlib
|
||||
ollama==0.1.5
|
||||
# via -r base.in
|
||||
openai==1.9.0
|
||||
# via -r base.in
|
||||
opentelemetry-api==1.21.0
|
||||
# via
|
||||
# -r base.in
|
||||
|
@ -360,6 +382,10 @@ pyasn1-modules==0.2.8
|
|||
# service-identity
|
||||
pycparser==2.21
|
||||
# via cffi
|
||||
pydantic==2.5.3
|
||||
# via openai
|
||||
pydantic-core==2.14.6
|
||||
# via pydantic
|
||||
pygments==2.17.2
|
||||
# via rich
|
||||
pyjwt==2.5.0
|
||||
|
@ -443,26 +469,35 @@ six==1.16.0
|
|||
# python-dateutil
|
||||
# service-identity
|
||||
sniffio==1.3.0
|
||||
# via anyio
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# openai
|
||||
sqlparse==0.4.4
|
||||
# via django
|
||||
tenacity==8.1.0
|
||||
# via celery-redbeat
|
||||
tqdm==4.65.0
|
||||
# via -r base.in
|
||||
# via
|
||||
# -r base.in
|
||||
# openai
|
||||
twisted[tls]==23.10.0
|
||||
# via
|
||||
# -r base.in
|
||||
# daphne
|
||||
txaio==23.1.1
|
||||
# via autobahn
|
||||
typing-extensions==4.4.0
|
||||
typing-extensions==4.7.1
|
||||
# via
|
||||
# -r base.in
|
||||
# azure-core
|
||||
# azure-storage-blob
|
||||
# dj-database-url
|
||||
# openai
|
||||
# opentelemetry-sdk
|
||||
# prosemirror
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# twisted
|
||||
# uvicorn
|
||||
tzdata==2023.3
|
||||
|
|
|
@ -345,7 +345,7 @@ types-requests==2.28.11.2
|
|||
# via djangorestframework-stubs
|
||||
types-urllib3==1.26.25
|
||||
# via types-requests
|
||||
typing-extensions==4.4.0
|
||||
typing-extensions==4.7.1
|
||||
# via
|
||||
# -c base.txt
|
||||
# black
|
||||
|
|
0
backend/src/baserow/api/generative_ai/__init__.py
Normal file
0
backend/src/baserow/api/generative_ai/__init__.py
Normal file
12
backend/src/baserow/api/generative_ai/errors.py
Normal file
12
backend/src/baserow/api/generative_ai/errors.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
from rest_framework.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
ERROR_GENERATIVE_AI_DOES_NOT_EXIST = (
|
||||
"ERROR_GENERATIVE_AI_DOES_NOT_EXIST",
|
||||
HTTP_400_BAD_REQUEST,
|
||||
"The requested generative AI does not exist.",
|
||||
)
|
||||
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE = (
|
||||
"ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE",
|
||||
HTTP_400_BAD_REQUEST,
|
||||
"The requested model does not belong to the provided type.",
|
||||
)
|
|
@ -1,6 +1,7 @@
|
|||
from rest_framework import serializers
|
||||
|
||||
from baserow.api.user_files.serializers import UserFileField
|
||||
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
|
||||
from baserow.core.models import Settings
|
||||
|
||||
|
||||
|
@ -23,6 +24,7 @@ class SettingsSerializer(serializers.ModelSerializer):
|
|||
required=False,
|
||||
help_text="Co-branding logo that's placed next to the Baserow logo (176x29).",
|
||||
)
|
||||
generative_ai = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Settings
|
||||
|
@ -38,6 +40,7 @@ class SettingsSerializer(serializers.ModelSerializer):
|
|||
"track_workspace_usage",
|
||||
"show_baserow_help_request",
|
||||
"co_branding_logo",
|
||||
"generative_ai",
|
||||
)
|
||||
extra_kwargs = {
|
||||
"allow_new_signups": {"required": False},
|
||||
|
@ -49,6 +52,9 @@ class SettingsSerializer(serializers.ModelSerializer):
|
|||
"show_baserow_help_request": {"required": False},
|
||||
}
|
||||
|
||||
def get_generative_ai(self, object):
|
||||
return generative_ai_model_type_registry.get_models_per_type()
|
||||
|
||||
|
||||
class InstanceIdSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
|
|
|
@ -1206,3 +1206,15 @@ if SENTRY_DSN:
|
|||
send_default_pii=False,
|
||||
environment=os.getenv("SENTRY_ENVIRONMENT", ""),
|
||||
)
|
||||
|
||||
BASEROW_OPENAI_API_KEY = os.getenv("BASEROW_OPENAI_API_KEY", None)
|
||||
BASEROW_OPENAI_ORGANIZATION = os.getenv("BASEROW_OPENAI_ORGANIZATION", "") or None
|
||||
BASEROW_OPENAI_MODELS = os.getenv("BASEROW_OPENAI_MODELS", "")
|
||||
BASEROW_OPENAI_MODELS = (
|
||||
BASEROW_OPENAI_MODELS.split(",") if BASEROW_OPENAI_MODELS else []
|
||||
)
|
||||
BASEROW_OLLAMA_HOST = os.getenv("BASEROW_OLLAMA_HOST", None)
|
||||
BASEROW_OLLAMA_MODELS = os.getenv("BASEROW_OLLAMA_MODELS", "")
|
||||
BASEROW_OLLAMA_MODELS = (
|
||||
BASEROW_OLLAMA_MODELS.split(",") if BASEROW_OLLAMA_MODELS else []
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.functional import lazy
|
||||
|
@ -319,3 +320,14 @@ class PasswordSerializer(serializers.CharField):
|
|||
return None
|
||||
|
||||
return make_password(data)
|
||||
|
||||
|
||||
class GenerateAIFieldValueViewSerializer(serializers.Serializer):
|
||||
row_ids = serializers.ListField(
|
||||
child=serializers.IntegerField(),
|
||||
max_length=settings.BATCH_ROWS_SIZE_LIMIT,
|
||||
help_text="The ids of the rows that the values should be generated for.",
|
||||
)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
return super().to_internal_value(data)
|
||||
|
|
|
@ -4,6 +4,7 @@ from baserow.contrib.database.fields.registries import field_type_registry
|
|||
|
||||
from .views import (
|
||||
AsyncDuplicateFieldView,
|
||||
AsyncGenerateAIFieldValuesView,
|
||||
FieldsView,
|
||||
FieldView,
|
||||
UniqueRowValueFieldView,
|
||||
|
@ -24,4 +25,9 @@ urlpatterns = field_type_registry.api_urls + [
|
|||
AsyncDuplicateFieldView.as_view(),
|
||||
name="async_duplicate",
|
||||
),
|
||||
re_path(
|
||||
r"(?P<field_id>[0-9]+)/generate-ai-field-values/$",
|
||||
AsyncGenerateAIFieldValuesView.as_view(),
|
||||
name="async_generate_ai_field_values",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -19,6 +19,10 @@ from baserow.api.decorators import (
|
|||
validate_query_parameters,
|
||||
)
|
||||
from baserow.api.errors import ERROR_USER_NOT_IN_GROUP
|
||||
from baserow.api.generative_ai.errors import (
|
||||
ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
||||
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
||||
)
|
||||
from baserow.api.jobs.errors import ERROR_MAX_JOB_COUNT_EXCEEDED
|
||||
from baserow.api.jobs.serializers import JobSerializer
|
||||
from baserow.api.schemas import (
|
||||
|
@ -46,6 +50,7 @@ from baserow.contrib.database.api.fields.errors import (
|
|||
ERROR_MAX_FIELD_COUNT_EXCEEDED,
|
||||
ERROR_RESERVED_BASEROW_FIELD_NAME,
|
||||
)
|
||||
from baserow.contrib.database.api.rows.errors import ERROR_ROW_DOES_NOT_EXIST
|
||||
from baserow.contrib.database.api.tables.errors import (
|
||||
ERROR_FAILED_TO_LOCK_TABLE_DUE_TO_CONFLICT,
|
||||
ERROR_TABLE_DOES_NOT_EXIST,
|
||||
|
@ -75,13 +80,16 @@ from baserow.contrib.database.fields.exceptions import (
|
|||
)
|
||||
from baserow.contrib.database.fields.handler import FieldHandler
|
||||
from baserow.contrib.database.fields.job_types import DuplicateFieldJobType
|
||||
from baserow.contrib.database.fields.models import Field
|
||||
from baserow.contrib.database.fields.models import AIField, Field
|
||||
from baserow.contrib.database.fields.operations import (
|
||||
CreateFieldOperationType,
|
||||
ListFieldsOperationType,
|
||||
ReadFieldOperationType,
|
||||
)
|
||||
from baserow.contrib.database.fields.registries import field_type_registry
|
||||
from baserow.contrib.database.fields.tasks import generate_ai_values_for_rows
|
||||
from baserow.contrib.database.rows.exceptions import RowDoesNotExist
|
||||
from baserow.contrib.database.rows.handler import RowHandler
|
||||
from baserow.contrib.database.table.exceptions import (
|
||||
FailedToLockTableDueToConflict,
|
||||
TableDoesNotExist,
|
||||
|
@ -92,6 +100,11 @@ from baserow.contrib.database.tokens.handler import TokenHandler
|
|||
from baserow.core.action.registries import action_type_registry
|
||||
from baserow.core.db import specific_iterator
|
||||
from baserow.core.exceptions import UserNotInWorkspace
|
||||
from baserow.core.generative_ai.exceptions import (
|
||||
GenerativeAITypeDoesNotExist,
|
||||
ModelDoesNotBelongToType,
|
||||
)
|
||||
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
|
||||
from baserow.core.handler import CoreHandler
|
||||
from baserow.core.jobs.exceptions import MaxJobCountExceeded
|
||||
from baserow.core.jobs.handler import JobHandler
|
||||
|
@ -103,6 +116,7 @@ from .serializers import (
|
|||
DuplicateFieldParamsSerializer,
|
||||
FieldSerializer,
|
||||
FieldSerializerWithRelatedFields,
|
||||
GenerateAIFieldValueViewSerializer,
|
||||
RelatedFieldsSerializer,
|
||||
UniqueRowValueParamsSerializer,
|
||||
UniqueRowValuesSerializer,
|
||||
|
@ -600,3 +614,87 @@ class AsyncDuplicateFieldView(APIView):
|
|||
|
||||
serializer = job_type_registry.get_serializer(job, JobSerializer)
|
||||
return Response(serializer.data, status=status.HTTP_202_ACCEPTED)
|
||||
|
||||
|
||||
class AsyncGenerateAIFieldValuesView(APIView):
|
||||
permission_classes = (IsAuthenticated,)
|
||||
|
||||
@extend_schema(
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="field_id",
|
||||
location=OpenApiParameter.PATH,
|
||||
type=OpenApiTypes.INT,
|
||||
description="The field to generate the value for.",
|
||||
),
|
||||
CLIENT_SESSION_ID_SCHEMA_PARAMETER,
|
||||
CLIENT_UNDO_REDO_ACTION_GROUP_ID_SCHEMA_PARAMETER,
|
||||
],
|
||||
tags=["Database table fields"],
|
||||
operation_id="generate_table_ai_field_value",
|
||||
description=(
|
||||
"Endpoint that's used by the AI field to start an sync task that "
|
||||
"will update the cell value of the provided row IDs based on the "
|
||||
"dynamically constructed prompt configured in the field settings."
|
||||
),
|
||||
request=None,
|
||||
responses={
|
||||
200: str,
|
||||
400: get_error_schema(
|
||||
[
|
||||
"ERROR_GENERATIVE_AI_DOES_NOT_EXIST",
|
||||
"ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE",
|
||||
]
|
||||
),
|
||||
404: get_error_schema(
|
||||
[
|
||||
"ERROR_FIELD_DOES_NOT_EXIST",
|
||||
"ERROR_ROW_DOES_NOT_EXIST",
|
||||
]
|
||||
),
|
||||
},
|
||||
)
|
||||
@transaction.atomic
|
||||
@map_exceptions(
|
||||
{
|
||||
FieldDoesNotExist: ERROR_FIELD_DOES_NOT_EXIST,
|
||||
RowDoesNotExist: ERROR_ROW_DOES_NOT_EXIST,
|
||||
UserNotInWorkspace: ERROR_USER_NOT_IN_GROUP,
|
||||
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
||||
ModelDoesNotBelongToType: ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
||||
}
|
||||
)
|
||||
@validate_body(GenerateAIFieldValueViewSerializer, return_validated=True)
|
||||
def post(self, request: Request, field_id: int, data) -> Response:
|
||||
ai_field = FieldHandler().get_field(
|
||||
field_id,
|
||||
base_queryset=AIField.objects.all().select_related(
|
||||
"table__database__workspace"
|
||||
),
|
||||
)
|
||||
|
||||
CoreHandler().check_permissions(
|
||||
request.user,
|
||||
ListFieldsOperationType.type,
|
||||
workspace=ai_field.table.database.workspace,
|
||||
context=ai_field.table,
|
||||
)
|
||||
|
||||
model = ai_field.table.get_model()
|
||||
req_row_ids = data["row_ids"]
|
||||
rows = RowHandler().get_rows(model, req_row_ids)
|
||||
if len(rows) != len(req_row_ids):
|
||||
found_rows_ids = [row.id for row in rows]
|
||||
raise RowDoesNotExist(sorted(list(set(req_row_ids) - set(found_rows_ids))))
|
||||
|
||||
generative_ai_model_type = generative_ai_model_type_registry.get(
|
||||
ai_field.ai_generative_ai_type
|
||||
)
|
||||
ai_models = generative_ai_model_type.get_enabled_models()
|
||||
|
||||
if ai_field.ai_generative_ai_model not in ai_models:
|
||||
raise ModelDoesNotBelongToType(model_name=ai_field.ai_generative_ai_model)
|
||||
|
||||
generate_ai_values_for_rows.delay(request.user.id, ai_field.id, req_row_ids)
|
||||
|
||||
return Response(status=status.HTTP_202_ACCEPTED)
|
||||
|
|
|
@ -162,6 +162,7 @@ class DatabaseConfig(AppConfig):
|
|||
plugin_registry.register(DatabasePlugin())
|
||||
|
||||
from .fields.field_types import (
|
||||
AIFieldType,
|
||||
AutonumberFieldType,
|
||||
BooleanFieldType,
|
||||
CountFieldType,
|
||||
|
@ -216,8 +217,10 @@ class DatabaseConfig(AppConfig):
|
|||
field_type_registry.register(UUIDFieldType())
|
||||
field_type_registry.register(AutonumberFieldType())
|
||||
field_type_registry.register(PasswordFieldType())
|
||||
field_type_registry.register(AIFieldType())
|
||||
|
||||
from .fields.field_converters import (
|
||||
AIFieldConverter,
|
||||
AutonumberFieldConverter,
|
||||
FileFieldConverter,
|
||||
FormulaFieldConverter,
|
||||
|
@ -244,6 +247,7 @@ class DatabaseConfig(AppConfig):
|
|||
field_converter_registry.register(FormulaFieldConverter())
|
||||
field_converter_registry.register(AutonumberFieldConverter())
|
||||
field_converter_registry.register(PasswordFieldConverter())
|
||||
field_converter_registry.register(AIFieldConverter())
|
||||
|
||||
from .fields.actions import (
|
||||
CreateFieldActionType,
|
||||
|
@ -726,6 +730,16 @@ class DatabaseConfig(AppConfig):
|
|||
|
||||
subject_type_registry.register(TokenSubjectType())
|
||||
|
||||
from baserow.contrib.database.data_providers.registries import (
|
||||
database_data_provider_type_registry,
|
||||
)
|
||||
|
||||
from .rows.data_providers import HumanReadableFieldsDataProviderType
|
||||
|
||||
database_data_provider_type_registry.register(
|
||||
HumanReadableFieldsDataProviderType()
|
||||
)
|
||||
|
||||
# notification_types
|
||||
from baserow.contrib.database.fields.notification_types import (
|
||||
CollaboratorAddedToRowNotificationType,
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from baserow.core.formula.registries import DataProviderTypeRegistry
|
||||
|
||||
database_data_provider_type_registry = DataProviderTypeRegistry()
|
|
@ -10,15 +10,18 @@ from baserow.contrib.database.db.schema import (
|
|||
)
|
||||
|
||||
from .models import (
|
||||
AIField,
|
||||
AutonumberField,
|
||||
FileField,
|
||||
FormulaField,
|
||||
LinkRowField,
|
||||
LongTextField,
|
||||
MultipleCollaboratorsField,
|
||||
MultipleSelectField,
|
||||
PasswordField,
|
||||
SelectOption,
|
||||
SingleSelectField,
|
||||
TextField,
|
||||
)
|
||||
from .registries import FieldConverter, field_type_registry
|
||||
|
||||
|
@ -74,6 +77,16 @@ class PasswordFieldConverter(RecreateFieldConverter):
|
|||
return to_password or from_password
|
||||
|
||||
|
||||
class AIFieldConverter(RecreateFieldConverter):
|
||||
type = "ai"
|
||||
|
||||
def is_applicable(self, from_model, from_field, to_field):
|
||||
from_ai = isinstance(from_field, AIField)
|
||||
to_ai = isinstance(to_field, AIField)
|
||||
to_text_fields = isinstance(to_field, (TextField, LongTextField))
|
||||
return from_ai and not (to_text_fields or to_ai) or not from_ai and to_ai
|
||||
|
||||
|
||||
class LinkRowFieldConverter(RecreateFieldConverter):
|
||||
type = "link_row"
|
||||
|
||||
|
|
|
@ -235,6 +235,14 @@ def construct_all_possible_field_kwargs(
|
|||
"uuid": [{"name": "uuid"}],
|
||||
"autonumber": [{"name": "autonumber"}],
|
||||
"password": [{"name": "password"}],
|
||||
"ai": [
|
||||
{
|
||||
"name": "ai",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "Who are you?",
|
||||
}
|
||||
],
|
||||
}
|
||||
# If you have added a new field please add an entry into the dict above with any
|
||||
# test worthy combinations of kwargs
|
||||
|
|
|
@ -40,6 +40,10 @@ from dateutil.parser import ParserError
|
|||
from loguru import logger
|
||||
from rest_framework import serializers
|
||||
|
||||
from baserow.api.generative_ai.errors import (
|
||||
ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
||||
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
||||
)
|
||||
from baserow.contrib.database.api.fields.errors import (
|
||||
ERROR_DATE_FORCE_TIMEZONE_OFFSET_ERROR,
|
||||
ERROR_INCOMPATIBLE_PRIMARY_FIELD_TYPE,
|
||||
|
@ -96,6 +100,12 @@ from baserow.core.expressions import DateTrunc
|
|||
from baserow.core.fields import SyncedDateTimeField
|
||||
from baserow.core.formula import BaserowFormulaException
|
||||
from baserow.core.formula.parser.exceptions import FormulaFunctionTypeDoesNotExist
|
||||
from baserow.core.formula.serializers import FormulaSerializerField
|
||||
from baserow.core.generative_ai.exceptions import (
|
||||
GenerativeAITypeDoesNotExist,
|
||||
ModelDoesNotBelongToType,
|
||||
)
|
||||
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
|
||||
from baserow.core.handler import CoreHandler
|
||||
from baserow.core.models import UserFile, WorkspaceUser
|
||||
from baserow.core.registries import ImportExportConfig
|
||||
|
@ -155,6 +165,7 @@ from .fields import (
|
|||
from .handler import FieldHandler
|
||||
from .models import (
|
||||
AbstractSelectOption,
|
||||
AIField,
|
||||
AutonumberField,
|
||||
BooleanField,
|
||||
CountField,
|
||||
|
@ -5898,3 +5909,102 @@ class PasswordFieldType(FieldType):
|
|||
# We don't want to expose the hash of the password, so we just show `True` or
|
||||
# `False` as string depending on whether the value is set.
|
||||
return bool(value)
|
||||
|
||||
|
||||
class AIFieldType(CollationSortMixin, FieldType):
|
||||
"""
|
||||
The AI field can automatically query a generative AI model based on the provided
|
||||
prompt. It's possible to reference other fields to generate a unique output.
|
||||
"""
|
||||
|
||||
type = "ai"
|
||||
model_class = AIField
|
||||
can_be_in_form_view = False
|
||||
keep_data_on_duplication = True
|
||||
allowed_fields = ["ai_generative_ai_type", "ai_generative_ai_model", "ai_prompt"]
|
||||
serializer_field_names = [
|
||||
"ai_generative_ai_type",
|
||||
"ai_generative_ai_model",
|
||||
"ai_prompt",
|
||||
]
|
||||
serializer_field_overrides = {
|
||||
"ai_prompt": FormulaSerializerField(
|
||||
help_text="The prompt that must run for each row. Must be an formula.",
|
||||
required=False,
|
||||
allow_blank=True,
|
||||
default="",
|
||||
),
|
||||
}
|
||||
api_exceptions_map = {
|
||||
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
||||
ModelDoesNotBelongToType: ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
||||
}
|
||||
can_get_unique_values = False
|
||||
|
||||
def get_serializer_field(self, instance, **kwargs):
|
||||
required = kwargs.get("required", False)
|
||||
return serializers.CharField(
|
||||
**{
|
||||
"required": required,
|
||||
"allow_null": not required,
|
||||
"allow_blank": not required,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def get_model_field(self, instance, **kwargs):
|
||||
return models.TextField(null=True, **kwargs)
|
||||
|
||||
def get_serializer_help_text(self, instance):
|
||||
return (
|
||||
"Holds a text value that is generated by a generative UI model using a "
|
||||
"dynamic prompt."
|
||||
)
|
||||
|
||||
def random_value(self, instance, fake, cache):
|
||||
return fake.name()
|
||||
|
||||
def to_baserow_formula_type(self, field) -> BaserowFormulaType:
|
||||
return BaserowFormulaTextType(nullable=True)
|
||||
|
||||
def from_baserow_formula_type(
|
||||
self, formula_type: BaserowFormulaTextType
|
||||
) -> TextField:
|
||||
return TextField()
|
||||
|
||||
def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> any:
|
||||
value = getattr(row, field.db_column)
|
||||
return collate_expression(Value(value))
|
||||
|
||||
def contains_query(self, *args):
|
||||
return contains_filter(*args)
|
||||
|
||||
def contains_word_query(self, *args):
|
||||
return contains_word_filter(*args)
|
||||
|
||||
def _validate_field_kwargs(self, ai_type, model_type):
|
||||
ai_type = generative_ai_model_type_registry.get(ai_type)
|
||||
models = ai_type.get_enabled_models()
|
||||
if model_type not in models:
|
||||
raise ModelDoesNotBelongToType(model_name=model_type)
|
||||
|
||||
def before_create(
|
||||
self, table, primary, allowed_field_values, order, user, field_kwargs
|
||||
):
|
||||
ai_type = field_kwargs.get("ai_generative_ai_type", None)
|
||||
model_type = field_kwargs.get("ai_generative_ai_model", None)
|
||||
self._validate_field_kwargs(ai_type, model_type)
|
||||
|
||||
def before_update(self, from_field, to_field_values, user, field_kwargs):
|
||||
update_field = None
|
||||
if isinstance(from_field, AIField):
|
||||
update_field = from_field
|
||||
|
||||
ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr(
|
||||
update_field, "ai_generative_ai_type", None
|
||||
)
|
||||
model_type = field_kwargs.get("ai_generative_ai_model", None) or getattr(
|
||||
update_field, "ai_generative_ai_model", None
|
||||
)
|
||||
|
||||
self._validate_field_kwargs(ai_type, model_type)
|
||||
|
|
|
@ -26,6 +26,7 @@ from baserow.contrib.database.table.constants import (
|
|||
MULTIPLE_SELECT_THROUGH_TABLE_PREFIX,
|
||||
get_tsv_vector_field_name,
|
||||
)
|
||||
from baserow.core.formula.field import FormulaField as ModelFormulaField
|
||||
from baserow.core.jobs.mixins import (
|
||||
JobWithUndoRedoIds,
|
||||
JobWithUserIpAddress,
|
||||
|
@ -740,6 +741,12 @@ class PasswordField(Field):
|
|||
pass
|
||||
|
||||
|
||||
class AIField(Field):
|
||||
ai_generative_ai_type = models.CharField(max_length=32, null=True)
|
||||
ai_generative_ai_model = models.CharField(max_length=32, null=True)
|
||||
ai_prompt = ModelFormulaField(default="")
|
||||
|
||||
|
||||
class DuplicateFieldJob(
|
||||
JobWithUserIpAddress, JobWithWebsocketId, JobWithUndoRedoIds, Job
|
||||
):
|
||||
|
|
|
@ -177,7 +177,6 @@ def notify_users_added_to_row_when_rows_updated(
|
|||
model,
|
||||
before_return,
|
||||
updated_field_ids,
|
||||
before_rows_values,
|
||||
m2m_change_tracker=None,
|
||||
**kwargs
|
||||
):
|
||||
|
|
|
@ -10,10 +10,25 @@ from loguru import logger
|
|||
from opentelemetry import trace
|
||||
|
||||
from baserow.config.celery import app
|
||||
from baserow.contrib.database.fields.handler import FieldHandler
|
||||
from baserow.contrib.database.fields.models import AIField
|
||||
from baserow.contrib.database.fields.operations import ListFieldsOperationType
|
||||
from baserow.contrib.database.fields.registries import field_type_registry
|
||||
from baserow.contrib.database.rows.exceptions import RowDoesNotExist
|
||||
from baserow.contrib.database.rows.handler import RowHandler
|
||||
from baserow.contrib.database.rows.runtime_formula_contexts import (
|
||||
HumanReadableRowContext,
|
||||
)
|
||||
from baserow.contrib.database.rows.signals import rows_ai_values_generation_error
|
||||
from baserow.contrib.database.search.handler import SearchHandler
|
||||
from baserow.core.formula import resolve_formula
|
||||
from baserow.core.formula.registries import formula_runtime_function_registry
|
||||
from baserow.core.generative_ai.exceptions import ModelDoesNotBelongToType
|
||||
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
|
||||
from baserow.core.handler import CoreHandler
|
||||
from baserow.core.models import Workspace
|
||||
from baserow.core.telemetry.utils import add_baserow_trace_attrs, baserow_trace
|
||||
from baserow.core.user.handler import User
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -105,6 +120,74 @@ def _run_periodic_field_type_update_per_workspace(
|
|||
SearchHandler().entire_field_values_changed_or_created(fields[0].table, fields)
|
||||
|
||||
|
||||
@app.task(bind=True, queue="export")
|
||||
def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list[int]):
|
||||
user = User.objects.get(pk=user_id)
|
||||
|
||||
ai_field = FieldHandler().get_field(
|
||||
field_id,
|
||||
base_queryset=AIField.objects.all().select_related(
|
||||
"table__database__workspace"
|
||||
),
|
||||
)
|
||||
table = ai_field.table
|
||||
|
||||
CoreHandler().check_permissions(
|
||||
user,
|
||||
ListFieldsOperationType.type,
|
||||
workspace=table.database.workspace,
|
||||
context=table,
|
||||
)
|
||||
|
||||
model = ai_field.table.get_model()
|
||||
req_row_ids = row_ids
|
||||
rows = RowHandler().get_rows(model, req_row_ids)
|
||||
if len(rows) != len(req_row_ids):
|
||||
found_rows_ids = [row.id for row in rows]
|
||||
raise RowDoesNotExist(sorted(list(set(req_row_ids) - set(found_rows_ids))))
|
||||
|
||||
generative_ai_model_type = generative_ai_model_type_registry.get(
|
||||
ai_field.ai_generative_ai_type
|
||||
)
|
||||
ai_models = generative_ai_model_type.get_enabled_models()
|
||||
|
||||
if ai_field.ai_generative_ai_model not in ai_models:
|
||||
raise ModelDoesNotBelongToType(model_name=ai_field.ai_generative_ai_model)
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
context = HumanReadableRowContext(row, exclude_field_ids=[ai_field.id])
|
||||
message = str(
|
||||
resolve_formula(
|
||||
ai_field.ai_prompt, formula_runtime_function_registry, context
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
value = generative_ai_model_type.prompt(
|
||||
ai_field.ai_generative_ai_model, message
|
||||
)
|
||||
except Exception as exc:
|
||||
# If the prompt fails once, we should not continue with the other rows.
|
||||
rows_ai_values_generation_error.send(
|
||||
self,
|
||||
user=user,
|
||||
rows=rows[i:],
|
||||
field=ai_field,
|
||||
table=table,
|
||||
error_message=str(exc),
|
||||
)
|
||||
raise exc
|
||||
|
||||
RowHandler().update_row_by_id(
|
||||
user,
|
||||
table,
|
||||
row.id,
|
||||
{ai_field.db_column: value},
|
||||
model=model,
|
||||
values_already_prepared=True,
|
||||
)
|
||||
|
||||
|
||||
@baserow_trace(tracer)
|
||||
def _run_periodic_field_update(field, field_type_instance, all_updated_fields):
|
||||
add_baserow_trace_attrs(field_id=field.id)
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Generated by Django 4.1.13 on 2024-02-08 17:26
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import baserow.core.formula.field
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0153_passwordfield"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="AIField",
|
||||
fields=[
|
||||
(
|
||||
"field_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="database.field",
|
||||
),
|
||||
),
|
||||
("ai_generative_ai_type", models.CharField(max_length=32, null=True)),
|
||||
("ai_generative_ai_model", models.CharField(max_length=32, null=True)),
|
||||
("ai_prompt", baserow.core.formula.field.FormulaField(default="")),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
bases=("database.field",),
|
||||
),
|
||||
]
|
32
backend/src/baserow/contrib/database/rows/data_providers.py
Normal file
32
backend/src/baserow/contrib/database/rows/data_providers.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from typing import List, Union
|
||||
|
||||
from baserow.contrib.builder.data_sources.builder_dispatch_context import (
|
||||
BuilderDispatchContext,
|
||||
)
|
||||
from baserow.core.formula.registries import DataProviderType
|
||||
|
||||
|
||||
class HumanReadableFieldsDataProviderType(DataProviderType):
|
||||
"""
|
||||
This data provider type is used to read the human readable values for the row
|
||||
fields. This is used for example in the AI field to be able to reference other
|
||||
fields in the same row to generate a different prompt for each row based on the
|
||||
values of the other fields.
|
||||
"""
|
||||
|
||||
type = "fields"
|
||||
|
||||
def get_data_chunk(
|
||||
self, dispatch_context: BuilderDispatchContext, path: List[str]
|
||||
) -> Union[int, str]:
|
||||
"""
|
||||
When a page parameter is read, returns the value previously saved from the
|
||||
request object.
|
||||
"""
|
||||
|
||||
if len(path) != 1:
|
||||
return None
|
||||
|
||||
first_part = path[0]
|
||||
|
||||
return dispatch_context.human_readable_row_values.get(first_part, "")
|
|
@ -28,7 +28,6 @@ from django.utils.encoding import force_str
|
|||
|
||||
from opentelemetry import metrics, trace
|
||||
|
||||
from baserow.contrib.database.api.rows.serializers import serialize_rows_for_response
|
||||
from baserow.contrib.database.fields.dependencies.handler import FieldDependencyHandler
|
||||
from baserow.contrib.database.fields.dependencies.update_collector import (
|
||||
FieldUpdateCollector,
|
||||
|
@ -276,7 +275,9 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
prepared_values_by_field[
|
||||
field_name
|
||||
] = field_type.prepare_value_for_db_in_bulk(
|
||||
field["field"], batch_values, continue_on_error=generate_error_report
|
||||
field["field"],
|
||||
batch_values,
|
||||
continue_on_error=generate_error_report,
|
||||
)
|
||||
|
||||
# replace original values to keep ordering
|
||||
|
@ -964,8 +965,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
updated_field_ids=updated_field_ids,
|
||||
)
|
||||
|
||||
before_rows_values = serialize_rows_for_response(rows, model)
|
||||
|
||||
if not values_already_prepared:
|
||||
prepared_values = self.prepare_values(model._field_objects, values)
|
||||
else:
|
||||
|
@ -1043,7 +1042,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
model=model,
|
||||
before_return=before_return,
|
||||
updated_field_ids=updated_field_ids,
|
||||
before_rows_values=before_rows_values,
|
||||
m2m_change_tracker=m2m_change_tracker,
|
||||
)
|
||||
|
||||
|
@ -1636,8 +1634,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
values["id"] = row.id
|
||||
original_row_values_by_id[row.id] = values
|
||||
|
||||
before_rows_values = serialize_rows_for_response(rows_to_update, model)
|
||||
|
||||
before_return = before_rows_update.send(
|
||||
self,
|
||||
rows=list(rows_to_update),
|
||||
|
@ -1815,7 +1811,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
model=model,
|
||||
before_return=before_return,
|
||||
updated_field_ids=updated_field_ids,
|
||||
before_rows_values=before_rows_values,
|
||||
m2m_change_tracker=m2m_change_tracker,
|
||||
)
|
||||
|
||||
|
@ -1829,6 +1824,19 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
fields_metadata_by_row_id,
|
||||
)
|
||||
|
||||
def get_rows(
|
||||
self, model: GeneratedTableModel, row_ids: List[int]
|
||||
) -> List[GeneratedTableModel]:
|
||||
"""
|
||||
Returns a list of rows based on the provided row ids.
|
||||
|
||||
:param model: The model that should be used to get the rows.
|
||||
:param row_ids: The list of row ids that should be fetched.
|
||||
:return: The list of rows.
|
||||
"""
|
||||
|
||||
return model.objects.filter(id__in=row_ids).enhance_by_fields()
|
||||
|
||||
def get_rows_for_update(
|
||||
self, model: GeneratedTableModel, row_ids: List[int]
|
||||
) -> RowsForUpdate:
|
||||
|
@ -1839,10 +1847,7 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
"""
|
||||
|
||||
return cast(
|
||||
RowsForUpdate,
|
||||
model.objects.select_for_update(of=("self",))
|
||||
.enhance_by_fields()
|
||||
.filter(id__in=row_ids),
|
||||
RowsForUpdate, self.get_rows(model, row_ids).select_for_update(of=("self",))
|
||||
)
|
||||
|
||||
def move_row_by_id(
|
||||
|
@ -1906,7 +1911,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
before_return = before_rows_update.send(
|
||||
self, rows=[row], user=user, table=table, model=model, updated_field_ids=[]
|
||||
)
|
||||
before_rows_values = serialize_rows_for_response([row], model)
|
||||
|
||||
row.order = self.get_unique_orders_before_row(before_row, model)[0]
|
||||
row.save()
|
||||
|
@ -1954,7 +1958,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
|
|||
model=model,
|
||||
before_return=before_return,
|
||||
updated_field_ids=[],
|
||||
before_rows_values=before_rows_values,
|
||||
prepared_rows_values=None,
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
from baserow.contrib.database.data_providers.registries import (
|
||||
database_data_provider_type_registry,
|
||||
)
|
||||
from baserow.core.formula.runtime_formula_context import RuntimeFormulaContext
|
||||
|
||||
|
||||
class HumanReadableRowContext(RuntimeFormulaContext):
|
||||
def __init__(self, row, exclude_field_ids=None):
|
||||
if exclude_field_ids is None:
|
||||
exclude_field_ids = []
|
||||
|
||||
model = row._meta.model
|
||||
self.human_readable_row_values = {
|
||||
f"field_{field['field'].id}": field["type"].get_human_readable_value(
|
||||
getattr(row, field["name"]), field
|
||||
)
|
||||
for field in model._field_objects.values()
|
||||
if field["field"].id not in exclude_field_ids
|
||||
}
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def data_provider_registry(self):
|
||||
return database_data_provider_type_registry
|
|
@ -8,6 +8,7 @@ before_rows_delete = Signal()
|
|||
rows_created = Signal()
|
||||
rows_updated = Signal()
|
||||
rows_deleted = Signal()
|
||||
rows_ai_values_generation_error = Signal()
|
||||
|
||||
row_orders_recalculated = Signal()
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from baserow.contrib.database.api.rows.serializers import (
|
|||
serialize_rows_for_response,
|
||||
)
|
||||
from baserow.contrib.database.webhooks.registries import WebhookEventType
|
||||
from baserow.contrib.database.ws.rows.signals import serialize_rows_values
|
||||
|
||||
from .signals import rows_created, rows_deleted, rows_updated
|
||||
|
||||
|
@ -62,19 +63,11 @@ class RowsUpdatedEventType(RowsEventType):
|
|||
signal = rows_updated
|
||||
|
||||
def get_payload(
|
||||
self,
|
||||
event_id,
|
||||
webhook,
|
||||
model,
|
||||
table,
|
||||
rows,
|
||||
before_return,
|
||||
before_rows_values,
|
||||
**kwargs
|
||||
self, event_id, webhook, model, table, rows, before_return, **kwargs
|
||||
):
|
||||
payload = super().get_payload(event_id, webhook, model, table, rows, **kwargs)
|
||||
|
||||
old_items = before_rows_values
|
||||
old_items = dict(before_return)[serialize_rows_values]
|
||||
|
||||
if webhook.use_user_field_names:
|
||||
old_items = remap_serialized_rows_to_user_field_names(old_items, model)
|
||||
|
|
|
@ -11,7 +11,10 @@ from baserow.contrib.database.rows import signals as row_signals
|
|||
from baserow.contrib.database.table.models import GeneratedTableModel
|
||||
from baserow.contrib.database.views.handler import PublicViewRows, ViewHandler
|
||||
from baserow.contrib.database.views.registries import view_type_registry
|
||||
from baserow.contrib.database.ws.rows.signals import RealtimeRowMessages
|
||||
from baserow.contrib.database.ws.rows.signals import (
|
||||
RealtimeRowMessages,
|
||||
serialize_rows_values,
|
||||
)
|
||||
from baserow.core.telemetry.utils import baserow_trace
|
||||
from baserow.ws.registries import page_registry
|
||||
|
||||
|
@ -148,18 +151,10 @@ def public_before_rows_update(
|
|||
@receiver(row_signals.rows_updated)
|
||||
@baserow_trace(tracer)
|
||||
def public_rows_updated(
|
||||
sender,
|
||||
rows,
|
||||
user,
|
||||
table,
|
||||
model,
|
||||
before_return,
|
||||
updated_field_ids,
|
||||
before_rows_values,
|
||||
**kwargs
|
||||
sender, rows, user, table, model, before_return, updated_field_ids, **kwargs
|
||||
):
|
||||
before_return_dict = dict(before_return)[public_before_rows_update]
|
||||
serialized_old_rows = before_rows_values
|
||||
serialized_old_rows = dict(before_return)[serialize_rows_values]
|
||||
serialized_updated_rows = serialize_rows_for_response(rows, model)
|
||||
|
||||
old_row_public_views: List[PublicViewRows] = before_return_dict[
|
||||
|
|
|
@ -7,6 +7,7 @@ from baserow.contrib.database.api.rows.serializers import (
|
|||
RowHistorySerializer,
|
||||
RowSerializer,
|
||||
get_row_serializer_class,
|
||||
serialize_rows_for_response,
|
||||
)
|
||||
from baserow.contrib.database.rows import signals as row_signals
|
||||
from baserow.contrib.database.rows.registries import row_metadata_registry
|
||||
|
@ -14,6 +15,13 @@ from baserow.contrib.database.table.models import GeneratedTableModel
|
|||
from baserow.ws.registries import page_registry
|
||||
|
||||
|
||||
@receiver(row_signals.before_rows_update)
|
||||
def serialize_rows_values(
|
||||
sender, rows, user, table, model, updated_field_ids, **kwargs
|
||||
):
|
||||
return serialize_rows_for_response(rows, model)
|
||||
|
||||
|
||||
@receiver(row_signals.rows_created)
|
||||
def rows_created(
|
||||
sender,
|
||||
|
@ -57,10 +65,10 @@ def rows_updated(
|
|||
model,
|
||||
before_return,
|
||||
updated_field_ids,
|
||||
before_rows_values,
|
||||
**kwargs,
|
||||
):
|
||||
table_page_type = page_registry.get("table")
|
||||
before_rows_values = dict(before_return)[serialize_rows_values]
|
||||
transaction.on_commit(
|
||||
lambda: table_page_type.broadcast(
|
||||
RealtimeRowMessages.rows_updated(
|
||||
|
@ -79,6 +87,26 @@ def rows_updated(
|
|||
)
|
||||
|
||||
|
||||
@receiver(row_signals.rows_ai_values_generation_error)
|
||||
def rows_ai_values_generation_error(
|
||||
sender, user, rows, field, table, error_message, **kwargs
|
||||
):
|
||||
table_page_type = page_registry.get("table")
|
||||
transaction.on_commit(
|
||||
lambda: table_page_type.broadcast(
|
||||
{
|
||||
"type": "rows_ai_values_generation_error",
|
||||
"field_id": field.id,
|
||||
"table_id": table.id,
|
||||
"row_ids": [row.id for row in rows],
|
||||
"error": error_message,
|
||||
},
|
||||
getattr(user, "web_socket_id", None),
|
||||
table_id=table.id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@receiver(row_signals.before_rows_delete)
|
||||
def before_rows_delete(sender, rows, user, table, model, **kwargs):
|
||||
return get_row_serializer_class(model, RowSerializer, is_response=True)(
|
||||
|
|
|
@ -296,6 +296,17 @@ class CoreConfig(AppConfig):
|
|||
)
|
||||
notification_type_registry.register(BaserowVersionUpgradeNotificationType())
|
||||
|
||||
from baserow.core.generative_ai.generative_ai_models import (
|
||||
OllamaGenerativeAIModelType,
|
||||
OpenAIGenerativeAIModelType,
|
||||
)
|
||||
from baserow.core.generative_ai.registries import (
|
||||
generative_ai_model_type_registry,
|
||||
)
|
||||
|
||||
generative_ai_model_type_registry.register(OpenAIGenerativeAIModelType())
|
||||
generative_ai_model_type_registry.register(OllamaGenerativeAIModelType())
|
||||
|
||||
# Must import the Posthog signal, otherwise it won't work.
|
||||
import baserow.core.posthog # noqa: F403, F401
|
||||
|
||||
|
|
0
backend/src/baserow/core/generative_ai/__init__.py
Normal file
0
backend/src/baserow/core/generative_ai/__init__.py
Normal file
17
backend/src/baserow/core/generative_ai/exceptions.py
Normal file
17
backend/src/baserow/core/generative_ai/exceptions.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
from baserow.core.exceptions import InstanceTypeDoesNotExist
|
||||
|
||||
|
||||
class GenerativeAITypeDoesNotExist(InstanceTypeDoesNotExist):
|
||||
"""Raised when trying to get a generative AI type that does not exist."""
|
||||
|
||||
|
||||
class ModelDoesNotBelongToType(Exception):
|
||||
"""Raised when trying to get a model that does not belong to the type."""
|
||||
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
self.model_name = model_name
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class GenerativeAIPromptError(Exception):
|
||||
"""Raised when an error occurs while prompting the model."""
|
|
@ -0,0 +1,60 @@
|
|||
from django.conf import settings
|
||||
|
||||
from ollama import Client as OllamaClient
|
||||
from ollama import RequestError as OllamaRequestError
|
||||
from ollama import ResponseError as OllamaResponseError
|
||||
from openai import APIStatusError as OpenAIAPIStatusError
|
||||
from openai import OpenAI, OpenAIError
|
||||
|
||||
from baserow.core.generative_ai.exceptions import GenerativeAIPromptError
|
||||
|
||||
from .registries import GenerativeAIModelType
|
||||
|
||||
|
||||
class OpenAIGenerativeAIModelType(GenerativeAIModelType):
|
||||
type = "openai"
|
||||
|
||||
def is_enabled(self):
|
||||
return bool(settings.BASEROW_OPENAI_API_KEY) and bool(self.get_enabled_models())
|
||||
|
||||
def get_enabled_models(self):
|
||||
return settings.BASEROW_OPENAI_MODELS
|
||||
|
||||
def prompt(self, model, prompt):
|
||||
client = OpenAI(
|
||||
api_key=settings.BASEROW_OPENAI_API_KEY,
|
||||
organization=settings.BASEROW_OPENAI_ORGANIZATION,
|
||||
)
|
||||
try:
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
stream=False,
|
||||
)
|
||||
except (OpenAIError, OpenAIAPIStatusError) as exc:
|
||||
raise GenerativeAIPromptError(str(exc)) from exc
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
|
||||
class OllamaGenerativeAIModelType(GenerativeAIModelType):
|
||||
type = "ollama"
|
||||
|
||||
def is_enabled(self):
|
||||
return bool(settings.BASEROW_OLLAMA_HOST) and bool(self.get_enabled_models())
|
||||
|
||||
def get_enabled_models(self):
|
||||
return settings.BASEROW_OLLAMA_MODELS
|
||||
|
||||
def prompt(self, model, prompt):
|
||||
client = OllamaClient(host=settings.BASEROW_OLLAMA_HOST)
|
||||
try:
|
||||
response = client.generate(model=model, prompt=prompt, stream=False)
|
||||
except (OllamaRequestError, OllamaResponseError) as exc:
|
||||
raise GenerativeAIPromptError(str(exc)) from exc
|
||||
|
||||
return response["response"]
|
31
backend/src/baserow/core/generative_ai/registries.py
Normal file
31
backend/src/baserow/core/generative_ai/registries.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from baserow.core.registry import Instance, Registry
|
||||
|
||||
from .exceptions import GenerativeAITypeDoesNotExist
|
||||
|
||||
|
||||
class GenerativeAIModelType(Instance):
|
||||
def is_enabled(self):
|
||||
return False
|
||||
|
||||
def get_enabled_models(self):
|
||||
return []
|
||||
|
||||
def prompt(self, model, prompt):
|
||||
raise NotImplementedError("The prompt function must be implemented.")
|
||||
|
||||
|
||||
class GenerativeAIModelTypeRegistry(Registry):
|
||||
name = "generative_ai_model_type"
|
||||
does_not_exist_exception_class = GenerativeAITypeDoesNotExist
|
||||
|
||||
def get_models_per_type(self):
|
||||
return {
|
||||
key: value.get_enabled_models()
|
||||
for key, value in self.registry.items()
|
||||
if value.is_enabled()
|
||||
}
|
||||
|
||||
|
||||
generative_ai_model_type_registry: GenerativeAIModelTypeRegistry = (
|
||||
GenerativeAIModelTypeRegistry()
|
||||
)
|
|
@ -9,6 +9,7 @@ from .domain import DomainFixtures
|
|||
from .element import ElementFixtures
|
||||
from .field import FieldFixtures
|
||||
from .file_import import FileImportFixtures
|
||||
from .generative_ai import GenerativeAIFixtures
|
||||
from .integration import IntegrationFixtures
|
||||
from .job import JobFixtures
|
||||
from .notifications import NotificationsFixture
|
||||
|
@ -59,6 +60,7 @@ class Fixtures(
|
|||
UserSourceFixtures,
|
||||
AppAuthProviderFixtures,
|
||||
UserSourceUserFixtures,
|
||||
GenerativeAIFixtures,
|
||||
):
|
||||
def __init__(self, fake=None):
|
||||
self.fake = fake
|
||||
|
|
|
@ -5,6 +5,7 @@ from baserow.contrib.database.fields.dependencies.handler import FieldDependency
|
|||
from baserow.contrib.database.fields.field_cache import FieldCache
|
||||
from baserow.contrib.database.fields.field_types import AutonumberFieldType
|
||||
from baserow.contrib.database.fields.models import (
|
||||
AIField,
|
||||
AutonumberField,
|
||||
BooleanField,
|
||||
CreatedByField,
|
||||
|
@ -405,6 +406,29 @@ class FieldFixtures:
|
|||
self.set_test_field_kwarg_defaults(user, kwargs)
|
||||
|
||||
field = PasswordField.objects.create(**kwargs)
|
||||
if create_field:
|
||||
self.create_model_field(kwargs["table"], field)
|
||||
|
||||
return field
|
||||
|
||||
def create_ai_field(self, user=None, create_field=True, **kwargs):
|
||||
self.set_test_field_kwarg_defaults(user, kwargs)
|
||||
|
||||
# Register the fake generative AI model for testing purposes.
|
||||
self.register_fake_generate_ai_type()
|
||||
|
||||
if "ai_generative_ai_type" not in kwargs:
|
||||
kwargs["ai_generative_ai_type"] = "test_generative_ai"
|
||||
|
||||
if "ai_generative_ai_model" not in kwargs:
|
||||
kwargs["ai_generative_ai_model"] = "test_1"
|
||||
|
||||
if "ai_prompt" not in kwargs:
|
||||
kwargs[
|
||||
"ai_prompt"
|
||||
] = "'What is your purpose? Answer with a maximum of 10 words.'"
|
||||
|
||||
field = AIField.objects.create(**kwargs)
|
||||
|
||||
if create_field:
|
||||
self.create_model_field(kwargs["table"], field)
|
||||
|
|
65
backend/src/baserow/test_utils/fixtures/generative_ai.py
Normal file
65
backend/src/baserow/test_utils/fixtures/generative_ai.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
from baserow.core.generative_ai.exceptions import GenerativeAIPromptError
|
||||
|
||||
|
||||
class GenerativeAIFixtures:
|
||||
def register_fake_generate_ai_type(self, **kwargs):
|
||||
from baserow.core.generative_ai.registries import (
|
||||
GenerativeAIModelType,
|
||||
generative_ai_model_type_registry,
|
||||
)
|
||||
|
||||
class TestGenerativeAINoModelType(GenerativeAIModelType):
|
||||
type = "test_generative_ai_no_model"
|
||||
|
||||
def is_enabled(self):
|
||||
return True
|
||||
|
||||
def get_enabled_models(self):
|
||||
return []
|
||||
|
||||
def prompt(self, model, prompt):
|
||||
return ""
|
||||
|
||||
class TestGenerativeAIModelType(GenerativeAIModelType):
|
||||
type = "test_generative_ai"
|
||||
|
||||
def is_enabled(self):
|
||||
return True
|
||||
|
||||
def get_enabled_models(self):
|
||||
return ["test_1"]
|
||||
|
||||
def prompt(self, model, prompt):
|
||||
return f"Generated: {prompt}"
|
||||
|
||||
class TestGenerativeAIModelTypePromptError(GenerativeAIModelType):
|
||||
type = "test_generative_ai_prompt_error"
|
||||
|
||||
def is_enabled(self):
|
||||
return True
|
||||
|
||||
def get_enabled_models(self):
|
||||
return ["test_1"]
|
||||
|
||||
def prompt(self, model, prompt):
|
||||
raise GenerativeAIPromptError("Test error")
|
||||
|
||||
if (
|
||||
TestGenerativeAINoModelType.type
|
||||
not in generative_ai_model_type_registry.registry
|
||||
):
|
||||
generative_ai_model_type_registry.register(TestGenerativeAINoModelType())
|
||||
|
||||
if (
|
||||
TestGenerativeAIModelType.type
|
||||
not in generative_ai_model_type_registry.registry
|
||||
):
|
||||
generative_ai_model_type_registry.register(TestGenerativeAIModelType())
|
||||
|
||||
if (
|
||||
TestGenerativeAIModelTypePromptError.type
|
||||
not in generative_ai_model_type_registry.registry
|
||||
):
|
||||
generative_ai_model_type_registry.register(
|
||||
TestGenerativeAIModelTypePromptError()
|
||||
)
|
|
@ -96,6 +96,8 @@ def setup_interesting_test_table(
|
|||
if database is None:
|
||||
database = data_fixture.create_database_application(user=user)
|
||||
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
|
||||
file_suffix = file_suffix or ""
|
||||
|
||||
try:
|
||||
|
@ -241,6 +243,7 @@ def setup_interesting_test_table(
|
|||
],
|
||||
"phone_number": "+4412345678",
|
||||
"password": "test",
|
||||
"ai": "I'm an AI.",
|
||||
}
|
||||
|
||||
with freeze_time("2020-02-01 01:23"):
|
||||
|
|
|
@ -0,0 +1,268 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
from django.conf import settings
|
||||
from django.shortcuts import reverse
|
||||
|
||||
import pytest
|
||||
from rest_framework.status import (
|
||||
HTTP_202_ACCEPTED,
|
||||
HTTP_400_BAD_REQUEST,
|
||||
HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
from baserow.contrib.database.rows.handler import RowHandler
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_generate_ai_field_value_view_field_does_not_exist(data_fixture, api_client):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(table=table, name="ai")
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{}],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": 0},
|
||||
),
|
||||
{"row_ids": [rows[0].id]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_404_NOT_FOUND
|
||||
assert response.json()["error"] == "ERROR_FIELD_DOES_NOT_EXIST"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_generate_ai_field_value_view_row_does_not_exist(data_fixture, api_client):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(table=table, name="ai")
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{}],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": [0]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_404_NOT_FOUND
|
||||
assert response.json()["error"] == "ERROR_ROW_DOES_NOT_EXIST"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_generate_ai_field_value_view_user_not_in_workspace(data_fixture, api_client):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
user_2, token_2 = data_fixture.create_user_and_token(
|
||||
email="test2@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(table=table, name="ai")
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{}],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": [rows[0].id]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token_2}",
|
||||
)
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response.json()["error"] == "ERROR_USER_NOT_IN_GROUP"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_generate_ai_field_value_view_generative_ai_does_not_exist(
|
||||
data_fixture, api_client
|
||||
):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(
|
||||
table=table, name="ai", ai_generative_ai_type="does_not_exist"
|
||||
)
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{}],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": [rows[0].id]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response.json()["error"] == "ERROR_GENERATIVE_AI_DOES_NOT_EXIST"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_generate_ai_field_value_view_generative_ai_model_does_not_belong_to_type(
|
||||
data_fixture, api_client
|
||||
):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(
|
||||
table=table, name="ai", ai_generative_ai_model="does_not_exist"
|
||||
)
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[
|
||||
{},
|
||||
],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": [rows[0].id]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response.json()["error"] == "ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
@patch("baserow.contrib.database.fields.tasks.generate_ai_values_for_rows.apply")
|
||||
def test_generate_ai_field_value_view_generative_ai(
|
||||
patched_generate_ai_values_for_rows, data_fixture, api_client
|
||||
):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt="'Hello'")
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{}],
|
||||
)
|
||||
assert patched_generate_ai_values_for_rows.call_count == 0
|
||||
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": [rows[0].id]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_202_ACCEPTED
|
||||
assert patched_generate_ai_values_for_rows.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_batch_generate_ai_field_value_limit(api_client, data_fixture):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt="'Hello'")
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{}] * (settings.BATCH_ROWS_SIZE_LIMIT + 1),
|
||||
)
|
||||
|
||||
row_ids = [row.id for row in rows]
|
||||
|
||||
# BATCH_ROWS_SIZE_LIMIT rows are allowed
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": row_ids[: settings.BATCH_ROWS_SIZE_LIMIT]},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_202_ACCEPTED
|
||||
|
||||
# BATCH_ROWS_SIZE_LIMIT + 1 rows are not allowed
|
||||
response = api_client.post(
|
||||
reverse(
|
||||
"api:database:fields:async_generate_ai_field_values",
|
||||
kwargs={"field_id": field.id},
|
||||
),
|
||||
{"row_ids": row_ids},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response.json()["error"] == "ERROR_REQUEST_BODY_VALIDATION"
|
||||
assert response.json()["detail"] == {
|
||||
"row_ids": [
|
||||
{
|
||||
"code": "max_length",
|
||||
"error": f"Ensure this field has no more than"
|
||||
f" {settings.BATCH_ROWS_SIZE_LIMIT} elements.",
|
||||
},
|
||||
],
|
||||
}
|
|
@ -375,6 +375,7 @@ def test_get_row_serializer_with_user_field_names(data_fixture):
|
|||
"uuid": "00000000-0000-4000-8000-000000000003",
|
||||
"autonumber": 2,
|
||||
"password": True,
|
||||
"ai": "I'm an AI.",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
|
@ -233,12 +233,12 @@ def test_can_export_every_interesting_different_field_to_csv(
|
|||
"phone_number,formula_text,formula_int,formula_bool,formula_decimal,formula_dateinterval,"
|
||||
"formula_date,formula_singleselect,formula_email,formula_link_with_label,"
|
||||
"formula_link_url_only,formula_multipleselect,count,rollup,lookup,uuid,"
|
||||
"autonumber,password\r\n"
|
||||
"autonumber,password,ai\r\n"
|
||||
"1,,,,,,,,,0,False,,,,,,,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
||||
"02/01/2021 13:00,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,02/01/2021 13:00,"
|
||||
"user@example.com,user@example.com,,,,,,,,,,,,,,,,,,,test FORMULA,1,True,33.3333333333,"
|
||||
"1d 0:00,2020-01-01,,,label (https://google.com),https://google.com,,0,0.000,,"
|
||||
"00000000-0000-4000-8000-000000000002,1,\r\n"
|
||||
"00000000-0000-4000-8000-000000000002,1,,\r\n"
|
||||
"2,text,long_text,https://www.google.com,test@example.com,-1,1,-1.2,1.2,3,True,"
|
||||
"02/01/2020 01:23,02/01/2020,01/02/2020 01:23,01/02/2020,01/02/2020 02:23,"
|
||||
"01/02/2020 02:23,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
||||
|
@ -253,7 +253,7 @@ def test_can_export_every_interesting_different_field_to_csv(
|
|||
'"user2@example.com,user3@example.com",\'+4412345678,test FORMULA,1,True,33.3333333333,'
|
||||
"1d 0:00,2020-01-01,A,test@example.com,label (https://google.com),https://google.com,"
|
||||
'"D,C,E",3,-122.222,"linked_row_1,linked_row_2,",00000000-0000-4000-8000-000000000003,'
|
||||
"2,True\r\n"
|
||||
"2,True,I'm an AI.\r\n"
|
||||
)
|
||||
|
||||
assert contents == expected
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from baserow.contrib.database.fields.tasks import generate_ai_values_for_rows
|
||||
from baserow.contrib.database.rows.handler import RowHandler
|
||||
from baserow.core.generative_ai.exceptions import GenerativeAIPromptError
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
|
||||
def test_generate_ai_field_value_view_generative_ai(patched_rows_updated, data_fixture):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user = data_fixture.create_user(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt="'Hello'")
|
||||
|
||||
rows = RowHandler().create_rows(user, table, rows_values=[{}])
|
||||
|
||||
assert patched_rows_updated.call_count == 0
|
||||
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
|
||||
assert patched_rows_updated.call_count == 1
|
||||
updated_row = patched_rows_updated.call_args[1]["rows"][0]
|
||||
assert getattr(updated_row, field.db_column) == "Generated: Hello"
|
||||
assert patched_rows_updated.call_args[1]["updated_field_ids"] == set([field.id])
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
|
||||
def test_generate_ai_field_value_view_generative_ai_parse_formula(
|
||||
patched_rows_updated, data_fixture
|
||||
):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user = data_fixture.create_user(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
firstname = data_fixture.create_text_field(table=table, name="firstname")
|
||||
lastname = data_fixture.create_text_field(table=table, name="lastname")
|
||||
formula = f"concat('Hello ', get('fields.field_{firstname.id}'), ' ', get('fields.field_{lastname.id}'))"
|
||||
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt=formula)
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[
|
||||
{f"field_{firstname.id}": "Bram", f"field_{lastname.id}": "Wiepjes"},
|
||||
],
|
||||
)
|
||||
|
||||
assert patched_rows_updated.call_count == 0
|
||||
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
|
||||
assert patched_rows_updated.call_count == 1
|
||||
updated_row = patched_rows_updated.call_args[1]["rows"][0]
|
||||
assert getattr(updated_row, field.db_column) == "Generated: Hello Bram Wiepjes"
|
||||
assert patched_rows_updated.call_args[1]["updated_field_ids"] == set([field.id])
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
|
||||
def test_generate_ai_field_value_view_generative_ai_invalid_field(
|
||||
patched_rows_updated, data_fixture
|
||||
):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user = data_fixture.create_user(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
firstname = data_fixture.create_text_field(table=table, name="firstname")
|
||||
formula = "concat('Hello ', get('fields.field_0'))"
|
||||
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt=formula)
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{f"field_{firstname.id}": "Bram"}],
|
||||
)
|
||||
assert patched_rows_updated.call_count == 0
|
||||
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
|
||||
assert patched_rows_updated.call_count == 1
|
||||
updated_row = patched_rows_updated.call_args[1]["rows"][0]
|
||||
assert getattr(updated_row, field.db_column) == "Generated: Hello "
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
@patch("baserow.contrib.database.rows.signals.rows_ai_values_generation_error.send")
|
||||
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
|
||||
def test_generate_ai_field_value_view_generative_ai_invalid_prompt(
|
||||
patched_rows_updated, patched_rows_ai_values_generation_error, data_fixture
|
||||
):
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
user = data_fixture.create_user(
|
||||
email="test@test.nl", password="password", first_name="Test1"
|
||||
)
|
||||
|
||||
database = data_fixture.create_database_application(user=user, name="database")
|
||||
table = data_fixture.create_database_table(name="table", database=database)
|
||||
firstname = data_fixture.create_text_field(table=table, name="firstname")
|
||||
formula = "concat('Hello ', get('fields.field_0'))"
|
||||
field = data_fixture.create_ai_field(
|
||||
table=table,
|
||||
name="ai",
|
||||
ai_generative_ai_type="test_generative_ai_prompt_error",
|
||||
ai_prompt=formula,
|
||||
)
|
||||
|
||||
rows = RowHandler().create_rows(
|
||||
user,
|
||||
table,
|
||||
rows_values=[{f"field_{firstname.id}": "Bram"}],
|
||||
)
|
||||
|
||||
assert patched_rows_ai_values_generation_error.call_count == 0
|
||||
|
||||
with pytest.raises(GenerativeAIPromptError):
|
||||
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
|
||||
|
||||
assert patched_rows_updated.call_count == 0
|
||||
assert patched_rows_ai_values_generation_error.call_count == 1
|
||||
call_args_rows = patched_rows_ai_values_generation_error.call_args[1]["rows"]
|
||||
assert len(call_args_rows) == 1
|
||||
assert rows[0].id == call_args_rows[0].id
|
||||
assert patched_rows_ai_values_generation_error.call_args[1]["field"] == field
|
||||
assert (
|
||||
patched_rows_ai_values_generation_error.call_args[1]["error_message"]
|
||||
== "Test error"
|
||||
)
|
|
@ -0,0 +1,240 @@
|
|||
from django.shortcuts import reverse
|
||||
|
||||
import pytest
|
||||
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
|
||||
|
||||
from baserow.contrib.database.fields.handler import FieldHandler
|
||||
from baserow.contrib.database.fields.models import AIField
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type(data_fixture):
|
||||
user = data_fixture.create_user()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
handler = FieldHandler()
|
||||
ai_field = handler.create_field(
|
||||
user=user,
|
||||
table=table,
|
||||
type_name="ai",
|
||||
name="ai_1",
|
||||
ai_generative_ai_type="test_generative_ai",
|
||||
ai_generative_ai_model="test_1",
|
||||
ai_prompt="'Who are you?'",
|
||||
)
|
||||
|
||||
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
||||
assert ai_field.ai_generative_ai_model == "test_1"
|
||||
assert ai_field.ai_prompt == "'Who are you?'"
|
||||
assert len(AIField.objects.all()) == 1
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_ai_field_type(data_fixture):
|
||||
user = data_fixture.create_user()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
field = data_fixture.create_ai_field(table=table, order=1, name="name")
|
||||
|
||||
handler = FieldHandler()
|
||||
ai_field = handler.update_field(
|
||||
user=user,
|
||||
field=field,
|
||||
name="ai_1",
|
||||
ai_generative_ai_type="test_generative_ai",
|
||||
ai_generative_ai_model="test_1",
|
||||
ai_prompt="'Who are you?'",
|
||||
)
|
||||
|
||||
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
||||
assert ai_field.ai_generative_ai_model == "test_1"
|
||||
assert ai_field.ai_prompt == "'Who are you?'"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_delete_ai_field_type(data_fixture):
|
||||
user = data_fixture.create_user()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
field = data_fixture.create_ai_field(
|
||||
table=table,
|
||||
order=1,
|
||||
name="name",
|
||||
ai_generative_ai_type="test_generative_ai",
|
||||
ai_generative_ai_model="test_1",
|
||||
ai_prompt="'Who are you?'",
|
||||
)
|
||||
|
||||
handler = FieldHandler()
|
||||
handler.delete_field(user=user, field=field)
|
||||
|
||||
assert len(AIField.objects.all()) == 0
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_200_OK
|
||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api_invalid_formula(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "ffff;;s9(",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
response_json = response.json()
|
||||
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
|
||||
assert response_json["detail"]["ai_prompt"][0]["code"] == "invalid"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api_with_invalid_type(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_generative_ai_type": "does_not_exist",
|
||||
"ai_generative_ai_model": "test_1",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response_json["error"] == "ERROR_GENERATIVE_AI_DOES_NOT_EXIST"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_create_ai_field_type_via_api_with_invalid_model(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
data_fixture.create_text_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.post(
|
||||
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||
{
|
||||
"name": "Test 1",
|
||||
"type": "ai",
|
||||
"ai_generative_ai_type": "test_generative_ai",
|
||||
"ai_generative_ai_model": "does_not_exist",
|
||||
"ai_prompt": "'Who are you?'",
|
||||
},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response_json["error"] == "ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_ai_field_type_via_api_with_invalid_type(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
field = data_fixture.create_ai_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{"ai_generative_ai_type": "does_not_exist"},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response_json["error"] == "ERROR_GENERATIVE_AI_DOES_NOT_EXIST"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_ai_field_type_via_api_with_invalid_model(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
field = data_fixture.create_ai_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{"ai_generative_ai_model": "does_not_exist"},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
response_json = response.json()
|
||||
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||
assert response_json["error"] == "ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.field_ai
|
||||
def test_update_ai_field_type_via_api_with_valid_model(data_fixture, api_client):
|
||||
user, token = data_fixture.create_user_and_token()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
field = data_fixture.create_ai_field(table=table, order=1, name="name")
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{"ai_generative_ai_model": "test_1"},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_200_OK
|
||||
|
||||
response = api_client.patch(
|
||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||
{"ai_generative_ai_type": "test_generative_ai"},
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||
)
|
||||
assert response.status_code == HTTP_200_OK
|
|
@ -26,6 +26,7 @@ from baserow.contrib.database.fields.field_helpers import (
|
|||
construct_all_possible_field_kwargs,
|
||||
)
|
||||
from baserow.contrib.database.fields.field_types import (
|
||||
AIFieldType,
|
||||
AutonumberFieldType,
|
||||
BooleanFieldType,
|
||||
CountFieldType,
|
||||
|
@ -363,6 +364,13 @@ def test_field_conversion_autonumber(data_fixture):
|
|||
_test_can_convert_between_fields(data_fixture, AutonumberFieldType.type)
|
||||
|
||||
|
||||
@pytest.mark.field_ai
|
||||
@pytest.mark.disabled_in_ci
|
||||
@pytest.mark.django_db
|
||||
def test_field_conversion_ai(data_fixture):
|
||||
_test_can_convert_between_fields(data_fixture, AIFieldType.type)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_field(data_fixture):
|
||||
user = data_fixture.create_user()
|
||||
|
|
|
@ -113,6 +113,7 @@ def test_fill_table_fields_with_add_all_fields(data_fixture):
|
|||
user = data_fixture.create_user()
|
||||
table = data_fixture.create_database_table(user=user)
|
||||
text_field = data_fixture.create_text_field(user=user, table=table)
|
||||
data_fixture.register_fake_generate_ai_type()
|
||||
|
||||
model = table.get_model()
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import pytest
|
||||
|
||||
from baserow.contrib.database.api.rows.serializers import serialize_rows_for_response
|
||||
from baserow.contrib.database.fields.handler import FieldHandler
|
||||
from baserow.contrib.database.rows.handler import RowHandler
|
||||
from baserow.contrib.database.webhooks.registries import webhook_event_type_registry
|
||||
from baserow.contrib.database.ws.rows.signals import serialize_rows_values
|
||||
|
||||
|
||||
@pytest.mark.django_db()
|
||||
|
@ -92,8 +92,11 @@ def test_rows_updated_event_type(data_fixture):
|
|||
row = model.objects.create(**{f"field_{text_field.id}": "Old Test value"})
|
||||
getattr(row, f"field_{link_row_field.id}").add(i1.id)
|
||||
|
||||
before_return = {}
|
||||
before_rows_values = serialize_rows_for_response([row], model)
|
||||
before_return = {
|
||||
serialize_rows_values: serialize_rows_values(
|
||||
None, [row], user, table, model, [text_field.id]
|
||||
)
|
||||
}
|
||||
|
||||
row = RowHandler().update_row_by_id(
|
||||
user=user,
|
||||
|
@ -116,7 +119,6 @@ def test_rows_updated_event_type(data_fixture):
|
|||
table=table,
|
||||
rows=[row],
|
||||
before_return=before_return,
|
||||
before_rows_values=before_rows_values,
|
||||
)
|
||||
assert payload == {
|
||||
"table_id": table.id,
|
||||
|
@ -151,7 +153,6 @@ def test_rows_updated_event_type(data_fixture):
|
|||
table=table,
|
||||
rows=[row],
|
||||
before_return=before_return,
|
||||
before_rows_values=before_rows_values,
|
||||
)
|
||||
assert payload == {
|
||||
"table_id": table.id,
|
||||
|
|
|
@ -1724,6 +1724,13 @@ def test_local_baserow_table_service_generate_schema_with_interesting_test_table
|
|||
"metadata": {},
|
||||
"type": "boolean",
|
||||
},
|
||||
field_db_column_by_name["ai"]: {
|
||||
"title": "ai",
|
||||
"default": None,
|
||||
"original_type": "ai",
|
||||
"metadata": {},
|
||||
"type": "string",
|
||||
},
|
||||
"id": {"metadata": {}, "type": "number", "title": "Id"},
|
||||
}
|
||||
|
||||
|
|
7
changelog/entries/unreleased/feature/ai_field.json
Normal file
7
changelog/entries/unreleased/feature/ai_field.json
Normal file
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "feature",
|
||||
"message": "Introduced a new AI field type",
|
||||
"issue_number": null,
|
||||
"bullet_points": [],
|
||||
"created_at": "2024-04-02"
|
||||
}
|
|
@ -164,6 +164,11 @@ x-backend-variables: &backend-variables
|
|||
BASEROW_BUILDER_DOMAINS:
|
||||
SENTRY_DSN:
|
||||
SENTRY_BACKEND_DSN:
|
||||
BASEROW_OPENAI_API_KEY:
|
||||
BASEROW_OPENAI_ORGANIZATION:
|
||||
BASEROW_OPENAI_MODELS:
|
||||
BASEROW_OLLAMA_HOST:
|
||||
BASEROW_OLLAMA_MODELS:
|
||||
|
||||
|
||||
services:
|
||||
|
|
|
@ -105,7 +105,8 @@ def test_can_export_every_interesting_different_field_to_json(
|
|||
"lookup": [],
|
||||
"uuid": "00000000-0000-4000-8000-000000000002",
|
||||
"autonumber": 1,
|
||||
"password": ""
|
||||
"password": "",
|
||||
"ai": ""
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
|
@ -221,7 +222,8 @@ def test_can_export_every_interesting_different_field_to_json(
|
|||
],
|
||||
"uuid": "00000000-0000-4000-8000-000000000003",
|
||||
"autonumber": 2,
|
||||
"password": true
|
||||
"password": true,
|
||||
"ai": "I'm an AI."
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
@ -371,6 +373,7 @@ def test_can_export_every_interesting_different_field_to_xml(
|
|||
<uuid>00000000-0000-4000-8000-000000000002</uuid>
|
||||
<autonumber>1</autonumber>
|
||||
<password/>
|
||||
<ai/>
|
||||
</row>
|
||||
<row>
|
||||
<id>2</id>
|
||||
|
@ -487,6 +490,7 @@ def test_can_export_every_interesting_different_field_to_xml(
|
|||
<uuid>00000000-0000-4000-8000-000000000003</uuid>
|
||||
<autonumber>2</autonumber>
|
||||
<password>true</password>
|
||||
<ai>I'm an AI.</ai>
|
||||
</row>
|
||||
</rows>
|
||||
"""
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
:visible-fields="visibleCardFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
:show-hidden-fields="showHiddenFieldsInRowModal"
|
||||
:all-fields-in-table="fields"
|
||||
@toggle-hidden-fields-visibility="
|
||||
showHiddenFieldsInRowModal = !showHiddenFieldsInRowModal
|
||||
"
|
||||
|
@ -35,7 +36,7 @@
|
|||
:database="database"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:fields="fields"
|
||||
:all-fields-in-table="fields"
|
||||
:primary-is-sortable="false"
|
||||
:visible-fields="visibleCardFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
|
|
|
@ -80,6 +80,7 @@
|
|||
:visible-fields="cardFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
:show-hidden-fields="showHiddenFieldsInRowModal"
|
||||
:all-fields-in-table="fields"
|
||||
@toggle-hidden-fields-visibility="
|
||||
showHiddenFieldsInRowModal = !showHiddenFieldsInRowModal
|
||||
"
|
||||
|
@ -95,7 +96,7 @@
|
|||
:database="database"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:fields="fields"
|
||||
:all-fields-in-table="fields"
|
||||
:primary-is-sortable="true"
|
||||
:visible-fields="cardFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
|
|
|
@ -120,7 +120,8 @@
|
|||
"singleSelectDropdown": "Dropdown",
|
||||
"singleSelectRadios": "Radios",
|
||||
"autonumber": "Autonumber",
|
||||
"password": "Password"
|
||||
"password": "Password",
|
||||
"ai": "AI prompt"
|
||||
},
|
||||
"fieldErrors": {
|
||||
"invalidNumber": "Invalid number",
|
||||
|
@ -311,13 +312,16 @@
|
|||
"disabledPasswordProviderMessage": "Please use another authentication provider.",
|
||||
"maxLocksPerTransactionExceededTitle": "PostgreSQL issue detected",
|
||||
"maxLocksPerTransactionExceededDescription": "Baserow attempted to permanently delete the trashed items, but exceeded the available locks specified in `max_locks_per_transaction`.",
|
||||
"disabledPasswordProviderMessage": "Please use another authentication provider.",
|
||||
"lastAdminTitle": "Can't remove last workspace admin",
|
||||
"lastAdminMessage": "A workspace has to have at least one admin.",
|
||||
"adminAlreadyExistsTitle": "Can't use that username",
|
||||
"adminAlreadyExistsDescription": "That username can't be used because it already exists.",
|
||||
"cannotCreateFieldTypeTitle": "Can't create field",
|
||||
"cannotCreateFieldTypeDescription": "The requested field type cannot be created at the moment. It might be just a temporary problem with older tables. Please try again later."
|
||||
"cannotCreateFieldTypeDescription": "The requested field type cannot be created at the moment. It might be just a temporary problem with older tables. Please try again later.",
|
||||
"generativeAIDoesNotExistTitle": "Generative AI does not exist",
|
||||
"generativeAIDoesNotExistDescription": "The generative AI model does not exist.",
|
||||
"modelDoesNotBelongToTypeTitle": "The selected model does not belong to the AI Type",
|
||||
"modelDoesNotBelongToTypeDescription": "The selected model does not belong to the selected AI type."
|
||||
},
|
||||
"importerType": {
|
||||
"csv": "Import a CSV file",
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
:nodes="nodes"
|
||||
:node-selected="nodeSelected"
|
||||
:loading="dataExplorerLoading"
|
||||
:application-context="applicationContext"
|
||||
@node-selected="dataExplorerItemSelected"
|
||||
@node-toggled="editor.commands.focus()"
|
||||
@focusin="dataExplorerFocused = true"
|
||||
|
@ -52,7 +53,10 @@ export default {
|
|||
},
|
||||
provide() {
|
||||
// Provide the application context to all formula components
|
||||
return { applicationContext: this.applicationContext }
|
||||
return {
|
||||
applicationContext: this.applicationContext,
|
||||
dataProviders: this.dataProviders,
|
||||
}
|
||||
},
|
||||
props: {
|
||||
value: {
|
||||
|
|
|
@ -40,11 +40,16 @@ export default {
|
|||
NodeViewWrapper,
|
||||
},
|
||||
mixins: [formulaComponent],
|
||||
inject: ['applicationContext'],
|
||||
inject: ['applicationContext', 'dataProviders'],
|
||||
data() {
|
||||
return { nodes: [], pathParts: [] }
|
||||
},
|
||||
computed: {
|
||||
availableData() {
|
||||
return Object.values(this.dataProviders).map((dataProvider) =>
|
||||
dataProvider.getNodes(this.applicationContext)
|
||||
)
|
||||
},
|
||||
isInvalid() {
|
||||
return this.findNode(this.nodes, _.toPath(this.path)) === null
|
||||
},
|
||||
|
@ -58,7 +63,10 @@ export default {
|
|||
return _.toPath(this.path)
|
||||
},
|
||||
dataProviderType() {
|
||||
return this.$registry.get('builderDataProvider', this.rawPathParts[0])
|
||||
const pathParts = this.rawPathParts
|
||||
return this.dataProviders.find(
|
||||
(dataProvider) => dataProvider.type === pathParts[0]
|
||||
)
|
||||
},
|
||||
},
|
||||
mounted() {
|
||||
|
|
|
@ -147,6 +147,14 @@ export class ClientErrorMap {
|
|||
app.i18n.t('clientHandler.lastAdminTitle'),
|
||||
app.i18n.t('clientHandler.lastAdminMessage')
|
||||
),
|
||||
ERROR_GENERATIVE_AI_DOES_NOT_EXIST: new ResponseErrorMessage(
|
||||
app.i18n.t('clientHandler.generativeAIDoesNotExistTitle'),
|
||||
app.i18n.t('clientHandler.generativeAIDoesNotExistDescription')
|
||||
),
|
||||
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE: new ResponseErrorMessage(
|
||||
app.i18n.t('clientHandler.modelDoesNotBelongToTypeTitle'),
|
||||
app.i18n.t('clientHandler.modelDoesNotBelongToTypeDescription')
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:forced-type="singleSelectFieldType"
|
||||
:all-fields-in-table="fields"
|
||||
@field-created="$event.callback()"
|
||||
></CreateFieldContext>
|
||||
</div>
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:forced-type="forcedType"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@submitted="submit"
|
||||
@keydown-enter="$refs.submitButton.focus()"
|
||||
>
|
||||
|
@ -57,6 +58,10 @@ export default {
|
|||
required: false,
|
||||
default: false,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
|
|
@ -32,7 +32,6 @@
|
|||
</template>
|
||||
|
||||
<script>
|
||||
import { mapGetters } from 'vuex'
|
||||
import { notifyIf } from '@baserow/modules/core/utils/error'
|
||||
import { createNewUndoRedoActionGroupId } from '@baserow/modules/database/utils/action'
|
||||
import FieldService from '@baserow/modules/database/services/field'
|
||||
|
@ -52,6 +51,10 @@ export default {
|
|||
type: Object,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
@ -62,14 +65,11 @@ export default {
|
|||
},
|
||||
computed: {
|
||||
existingFieldName() {
|
||||
return this.fields.map((field) => field.name)
|
||||
return this.allFieldsInTable.map((field) => field.name)
|
||||
},
|
||||
formFieldTypeIsReadOnly() {
|
||||
return this.$registry.get('field', this.fromField.type).isReadOnly
|
||||
},
|
||||
...mapGetters({
|
||||
fields: 'field/getAll',
|
||||
}),
|
||||
},
|
||||
methods: {
|
||||
onDuplicationEnd() {
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
<template>
|
||||
<div>
|
||||
<div class="control">
|
||||
<label class="control__label control__label--small">{{
|
||||
$t('fieldAISubForm.AIType')
|
||||
}}</label>
|
||||
<div class="control__elements">
|
||||
<Dropdown
|
||||
v-model="values.ai_generative_ai_type"
|
||||
class="dropdown--floating"
|
||||
:class="{
|
||||
'dropdown--error': $v.values.ai_generative_ai_type.$error,
|
||||
}"
|
||||
:fixed-items="true"
|
||||
:show-search="false"
|
||||
small
|
||||
@hide="$v.values.ai_generative_ai_type.$touch()"
|
||||
@change="$refs.aiModel.select(aIModelsPerType[0])"
|
||||
>
|
||||
<DropdownItem
|
||||
v-for="aIType in aITypes"
|
||||
:key="aIType"
|
||||
:name="aIType"
|
||||
:value="aIType"
|
||||
/>
|
||||
</Dropdown>
|
||||
</div>
|
||||
</div>
|
||||
<div class="control">
|
||||
<label class="control__label control__label--small">
|
||||
{{ $t('fieldAISubForm.AIModel') }}
|
||||
</label>
|
||||
<div class="control__elements">
|
||||
<Dropdown
|
||||
ref="aiModel"
|
||||
v-model="values.ai_generative_ai_model"
|
||||
class="dropdown--floating"
|
||||
:class="{
|
||||
'dropdown--error': $v.values.ai_generative_ai_model.$error,
|
||||
}"
|
||||
:fixed-items="true"
|
||||
:show-search="false"
|
||||
small
|
||||
@hide="$v.values.ai_generative_ai_model.$touch()"
|
||||
>
|
||||
<DropdownItem
|
||||
v-for="aIType in aIModelsPerType"
|
||||
:key="aIType"
|
||||
:name="aIType"
|
||||
:value="aIType"
|
||||
/>
|
||||
</Dropdown>
|
||||
</div>
|
||||
</div>
|
||||
<div class="control">
|
||||
<label class="control__label control__label--small">
|
||||
{{ $t('fieldAISubForm.prompt') }}
|
||||
</label>
|
||||
<div class="control__elements">
|
||||
<div style="max-width: 366px">
|
||||
<FormulaInputField
|
||||
v-model="values.ai_prompt"
|
||||
:data-providers="dataProviders"
|
||||
:application-context="applicationContext"
|
||||
placeholder="What is Baserow?"
|
||||
></FormulaInputField>
|
||||
</div>
|
||||
<div
|
||||
v-if="$v.values.ai_prompt.$dirty && $v.values.ai_prompt.$error"
|
||||
class="error"
|
||||
>
|
||||
{{ $t('error.requiredField') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import { mapGetters } from 'vuex'
|
||||
import { required } from 'vuelidate/lib/validators'
|
||||
|
||||
import form from '@baserow/modules/core/mixins/form'
|
||||
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
|
||||
import FormulaInputField from '@baserow/modules/core/components/formula/FormulaInputField'
|
||||
|
||||
export default {
|
||||
name: 'FieldAISubForm',
|
||||
components: { FormulaInputField },
|
||||
mixins: [form, fieldSubForm],
|
||||
data() {
|
||||
return {
|
||||
allowedValues: [
|
||||
'ai_generative_ai_type',
|
||||
'ai_generative_ai_model',
|
||||
'ai_prompt',
|
||||
],
|
||||
values: {
|
||||
ai_generative_ai_type: null,
|
||||
ai_generative_ai_model: null,
|
||||
ai_prompt: '',
|
||||
},
|
||||
}
|
||||
},
|
||||
computed: {
|
||||
...mapGetters({
|
||||
settings: 'settings/get',
|
||||
}),
|
||||
applicationContext() {
|
||||
const context = {}
|
||||
Object.defineProperty(context, 'fields', {
|
||||
enumerable: true,
|
||||
get: () =>
|
||||
this.allFieldsInTable.filter((f) => {
|
||||
const isNotThisField = f.id !== this.defaultValues.id
|
||||
return isNotThisField
|
||||
}),
|
||||
})
|
||||
return context
|
||||
},
|
||||
dataProviders() {
|
||||
return [this.$registry.get('databaseDataProvider', 'fields')]
|
||||
},
|
||||
aITypes() {
|
||||
return Object.keys(this.settings.generative_ai)
|
||||
},
|
||||
aIModelsPerType() {
|
||||
return (
|
||||
this.settings.generative_ai[this.values.ai_generative_ai_type] || []
|
||||
)
|
||||
},
|
||||
},
|
||||
validations: {
|
||||
values: {
|
||||
ai_generative_ai_type: { required },
|
||||
ai_generative_ai_model: { required },
|
||||
ai_prompt: { required },
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
|
@ -36,6 +36,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:field="field"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@update="$emit('update', $event)"
|
||||
@updated="$refs.context.hide()"
|
||||
></UpdateFieldContext>
|
||||
|
@ -94,6 +95,10 @@ export default {
|
|||
type: Object,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
<template>
|
||||
<div>
|
||||
<FieldSelectThroughFieldSubForm
|
||||
:fields="fields"
|
||||
:fields="allFieldsInTable"
|
||||
:database="database"
|
||||
:default-values="defaultValues"
|
||||
></FieldSelectThroughFieldSubForm>
|
||||
|
@ -27,12 +27,6 @@ export default {
|
|||
database() {
|
||||
return this.$store.getters['application/get'](this.table.database_id)
|
||||
},
|
||||
fields() {
|
||||
// This part might fail in the future because we can't 100% depend on that the
|
||||
// fields in the store are related to the component that renders this. An example
|
||||
// is if you edit the field type in a row edit modal of a related table.
|
||||
return this.$store.getters['field/getAll']
|
||||
},
|
||||
},
|
||||
validations: {},
|
||||
}
|
||||
|
|
|
@ -60,7 +60,10 @@
|
|||
:icon="fieldType.iconClass"
|
||||
:name="fieldType.getName()"
|
||||
:value="fieldType.type"
|
||||
:disabled="primary && !fieldType.canBePrimaryField"
|
||||
:disabled="
|
||||
(primary && !fieldType.canBePrimaryField) ||
|
||||
!fieldType.isEnabled()
|
||||
"
|
||||
></DropdownItem>
|
||||
</Dropdown>
|
||||
<div v-if="$v.values.type.$error" class="error">
|
||||
|
@ -76,6 +79,7 @@
|
|||
:field-type="values.type"
|
||||
:view="view"
|
||||
:primary="primary"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
:name="values.name"
|
||||
:default-values="defaultValues"
|
||||
@validate="$v.$touch"
|
||||
|
@ -121,6 +125,10 @@ export default {
|
|||
required: false,
|
||||
default: null,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
@ -201,7 +209,10 @@ export default {
|
|||
return !RESERVED_BASEROW_FIELD_NAMES.includes(param?.trim())
|
||||
},
|
||||
getFormComponent(type) {
|
||||
return this.$registry.get('field', type).getFormComponent()
|
||||
const fieldType = this.$registry.get('field', type)
|
||||
if (fieldType.isEnabled()) {
|
||||
return fieldType.getFormComponent()
|
||||
}
|
||||
},
|
||||
showFieldTypesDropdown(target) {
|
||||
this.$refs.fieldTypesDropdown.show(target)
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
:view="view"
|
||||
:loading="refreshingFormula"
|
||||
:formula-type-refresh-needed="formulaTypeRefreshNeeded"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@open-advanced-context="
|
||||
$refs.advancedFormulaEditContext.openContext($event)
|
||||
"
|
||||
|
@ -30,7 +31,6 @@
|
|||
|
||||
<script>
|
||||
import { required } from 'vuelidate/lib/validators'
|
||||
import { mapGetters } from 'vuex'
|
||||
|
||||
import form from '@baserow/modules/core/mixins/form'
|
||||
import { notifyIf } from '@baserow/modules/core/utils/error'
|
||||
|
@ -69,9 +69,6 @@ export default {
|
|||
}
|
||||
},
|
||||
computed: {
|
||||
...mapGetters({
|
||||
rawFields: 'field/getAll',
|
||||
}),
|
||||
localOrServerFormulaType() {
|
||||
return (
|
||||
this.mergedTypeOptions.array_formula_type ||
|
||||
|
@ -79,7 +76,7 @@ export default {
|
|||
)
|
||||
},
|
||||
fieldsUsableInFormula() {
|
||||
return this.rawFields.filter((f) => {
|
||||
return this.allFieldsInTable.filter((f) => {
|
||||
const isNotThisField = f.id !== this.defaultValues.id
|
||||
const canBeReferencedByFormulaField = this.$registry
|
||||
.get('field', f.type)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
<template>
|
||||
<div>
|
||||
<FieldSelectThroughFieldSubForm
|
||||
:fields="fields"
|
||||
:fields="allFieldsInTable"
|
||||
:database="database"
|
||||
:default-values="defaultValues"
|
||||
@input="selectedThroughField = $event"
|
||||
|
@ -20,6 +20,7 @@
|
|||
:formula-type="targetFieldFormulaType"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
>
|
||||
</FormulaTypeSubForms>
|
||||
</template>
|
||||
|
@ -27,8 +28,6 @@
|
|||
</template>
|
||||
|
||||
<script>
|
||||
import { mapGetters } from 'vuex'
|
||||
|
||||
import form from '@baserow/modules/core/mixins/form'
|
||||
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
|
||||
import FormulaTypeSubForms from '@baserow/modules/database/components/formula/FormulaTypeSubForms'
|
||||
|
@ -66,12 +65,6 @@ export default {
|
|||
}
|
||||
return 'unknown'
|
||||
},
|
||||
...mapGetters({
|
||||
// This part might fail in the future because we can't 100% depend on that the
|
||||
// fields in the store are related to the component that renders this. An example
|
||||
// is if you edit the field type in a row edit modal of a related table.
|
||||
fields: 'field/getAll',
|
||||
}),
|
||||
},
|
||||
validations: {},
|
||||
methods: {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
<template>
|
||||
<div>
|
||||
<FieldSelectThroughFieldSubForm
|
||||
:fields="fields"
|
||||
:fields="allFieldsInTable"
|
||||
:database="database"
|
||||
:default-values="defaultValues"
|
||||
@input="selectedThroughField = $event"
|
||||
|
@ -43,6 +43,7 @@
|
|||
:formula-type="targetFieldFormulaType"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
>
|
||||
</FormulaTypeSubForms>
|
||||
</template>
|
||||
|
@ -53,7 +54,6 @@
|
|||
</template>
|
||||
|
||||
<script>
|
||||
import { mapGetters } from 'vuex'
|
||||
import { required } from 'vuelidate/lib/validators'
|
||||
|
||||
import form from '@baserow/modules/core/mixins/form'
|
||||
|
@ -99,12 +99,6 @@ export default {
|
|||
(f) => f.isRollupCompatible()
|
||||
)
|
||||
},
|
||||
...mapGetters({
|
||||
// This part might fail in the future because we can't 100% depend on that the
|
||||
// fields in the store are related to the component that renders this. An example
|
||||
// is if you edit the field type in a row edit modal of a related table.
|
||||
fields: 'field/getAll',
|
||||
}),
|
||||
},
|
||||
validations: {
|
||||
values: {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
:view="view"
|
||||
:force-typed="forcedType"
|
||||
:use-action-group-id="true"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-created="$emit('field-created', $event)"
|
||||
@field-created-callback-done="updateInsertedFieldOrder"
|
||||
></CreateFieldContext>
|
||||
|
@ -35,6 +36,10 @@ export default {
|
|||
type: Object,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
:view="view"
|
||||
:default-values="field"
|
||||
:primary="field.primary"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@submitted="submit"
|
||||
>
|
||||
<div
|
||||
|
@ -23,7 +24,7 @@
|
|||
type="submit"
|
||||
class="button"
|
||||
:class="{ 'button--loading': loading }"
|
||||
:disabled="loading"
|
||||
:disabled="loading || fieldTypeDisabled"
|
||||
>
|
||||
{{ $t('action.save') }}
|
||||
</button>
|
||||
|
@ -54,12 +55,21 @@ export default {
|
|||
type: Object,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
loading: false,
|
||||
}
|
||||
},
|
||||
computed: {
|
||||
fieldTypeDisabled() {
|
||||
return !this.$registry.get('field', this.field.type).isEnabled()
|
||||
},
|
||||
},
|
||||
watch: {
|
||||
field() {
|
||||
// If the field values are updated via an outside source, think of real time
|
||||
|
|
|
@ -40,6 +40,7 @@
|
|||
:formula-type="formulaType"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
>
|
||||
</FormulaTypeSubForms>
|
||||
</template>
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:allow-set-number-negative="false"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
>
|
||||
</FieldNumberSubForm>
|
||||
<FieldDateSubForm
|
||||
|
@ -13,6 +14,7 @@
|
|||
:default-values="defaultValues"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
>
|
||||
</FieldDateSubForm>
|
||||
<FieldDurationSubForm
|
||||
|
@ -20,6 +22,7 @@
|
|||
:default-values="defaultValues"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
>
|
||||
</FieldDurationSubForm>
|
||||
</div>
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
:database="database"
|
||||
:table="table"
|
||||
:rows="[]"
|
||||
:fields="fields"
|
||||
:all-fields-in-table="fields"
|
||||
:visible-fields="fields"
|
||||
:fields-sortable="fieldsSortable"
|
||||
:can-modify-fields="canModifyFields"
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
:table="table"
|
||||
:database="database"
|
||||
:can-modify-fields="canModifyFields"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-updated="$emit('field-updated', $event)"
|
||||
@field-deleted="$emit('field-deleted')"
|
||||
@order-fields="$emit('order-fields', $event)"
|
||||
|
@ -42,6 +43,7 @@
|
|||
:view="view"
|
||||
:database="database"
|
||||
:can-modify-fields="canModifyFields"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-updated="$emit('field-updated', $event)"
|
||||
@field-deleted="$emit('field-deleted')"
|
||||
@toggle-field-visibility="$emit('toggle-field-visibility', $event)"
|
||||
|
@ -122,6 +124,10 @@ export default {
|
|||
required: false,
|
||||
default: false,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
<template>
|
||||
<div class="control__elements">
|
||||
<textarea
|
||||
ref="input"
|
||||
v-model="value"
|
||||
type="text"
|
||||
class="input field-long-text margin-bottom-2"
|
||||
:disabled="true"
|
||||
/>
|
||||
<a
|
||||
v-if="rowIsCreated"
|
||||
class="button button--ghost"
|
||||
:class="{ 'button--loading': generating }"
|
||||
@click="generate()"
|
||||
>{{ $t('rowEditFieldAI.generate') }}</a
|
||||
>
|
||||
<div v-else>{{ $t('rowEditFieldAI.createRowBefore') }}</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import rowEditField from '@baserow/modules/database/mixins/rowEditField'
|
||||
import fieldAI from '@baserow/modules/database/mixins/fieldAI'
|
||||
|
||||
export default {
|
||||
mixins: [rowEditField, fieldAI],
|
||||
}
|
||||
</script>
|
|
@ -53,6 +53,7 @@
|
|||
:view="view"
|
||||
:table="table"
|
||||
:database="database"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-updated="$emit('field-updated', $event)"
|
||||
@field-deleted="$emit('field-deleted')"
|
||||
@order-fields="$emit('order-fields', $event)"
|
||||
|
@ -77,6 +78,7 @@
|
|||
:view="view"
|
||||
:table="table"
|
||||
:database="database"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-updated="$emit('field-updated', $event)"
|
||||
@field-deleted="$emit('field-deleted')"
|
||||
@toggle-field-visibility="$emit('toggle-field-visibility', $event)"
|
||||
|
@ -109,6 +111,7 @@
|
|||
ref="createFieldContext"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-created="$emit('field-created', $event)"
|
||||
></CreateFieldContext>
|
||||
</div>
|
||||
|
@ -118,7 +121,7 @@
|
|||
:row="row"
|
||||
:table="table"
|
||||
:database="database"
|
||||
:fields="fields"
|
||||
:fields="allFieldsInTable"
|
||||
:read-only="readOnly"
|
||||
></RowEditModalSidebar>
|
||||
</template>
|
||||
|
@ -157,7 +160,7 @@ export default {
|
|||
required: false,
|
||||
default: null,
|
||||
},
|
||||
fields: {
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:field="field"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@update="$emit('field-updated', $event)"
|
||||
@delete="$emit('field-deleted')"
|
||||
>
|
||||
|
@ -43,6 +44,7 @@
|
|||
:field="field"
|
||||
:value="row['field_' + field.id]"
|
||||
:read-only="readOnly"
|
||||
:row-is-created="!!row.id"
|
||||
@update="update"
|
||||
@refresh-row="$emit('refresh-row', row)"
|
||||
/>
|
||||
|
@ -100,6 +102,15 @@ export default {
|
|||
type: Boolean,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
rowIsCreated: {
|
||||
type: Boolean,
|
||||
required: false,
|
||||
default: () => true,
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
getFieldComponent(type) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
:database="database"
|
||||
:sortable="sortable && fieldIsSortable(field)"
|
||||
:can-modify-fields="canModifyFields"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-updated="$emit('field-updated', $event)"
|
||||
@field-deleted="$emit('field-deleted')"
|
||||
@update="$emit('update', $event)"
|
||||
|
@ -88,6 +89,10 @@ export default {
|
|||
type: Object,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
fieldIsSortable(field) {
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
:database="database"
|
||||
:table="table"
|
||||
:sortable="false"
|
||||
:all-fields-in-table="allFields"
|
||||
:visible-fields="allFields"
|
||||
:can-modify-fields="false"
|
||||
@created="createRow"
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
:view="view"
|
||||
:fields="disabledFields"
|
||||
:enabled-fields="enabledFields"
|
||||
:all-fields-in-table="fields"
|
||||
:read-only="
|
||||
readOnly ||
|
||||
!$hasPermission(
|
||||
|
|
|
@ -68,6 +68,7 @@
|
|||
ref="createFieldContext"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-created="$event.callback()"
|
||||
></CreateFieldContext>
|
||||
</div>
|
||||
|
@ -123,6 +124,10 @@ export default {
|
|||
type: Boolean,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
computed: {
|
||||
modeType() {
|
||||
|
|
|
@ -86,6 +86,7 @@
|
|||
:visible-fields="cardFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
:show-hidden-fields="showHiddenFieldsInRowModal"
|
||||
:all-fields-in-table="fields"
|
||||
@toggle-hidden-fields-visibility="
|
||||
showHiddenFieldsInRowModal = !showHiddenFieldsInRowModal
|
||||
"
|
||||
|
@ -101,7 +102,7 @@
|
|||
:database="database"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:fields="fields"
|
||||
:all-fields-in-table="fields"
|
||||
:primary-is-sortable="true"
|
||||
:visible-fields="cardFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
|
|
|
@ -156,6 +156,15 @@
|
|||
:max-height-if-outside-viewport="true"
|
||||
>
|
||||
<ul v-show="isMultiSelectActive" class="context__menu">
|
||||
<component
|
||||
:is="contextItemComponent"
|
||||
v-for="(contextItemComponent, index) in getMultiSelectContextItems()"
|
||||
:key="index"
|
||||
:field="getSelectedField()"
|
||||
:rows="getSelectedRows()"
|
||||
:store-prefix="storePrefix"
|
||||
@click=";[$refs.rowContext.hide()]"
|
||||
></component>
|
||||
<li class="context__menu-item">
|
||||
<a
|
||||
class="context__menu-item-link"
|
||||
|
@ -297,7 +306,7 @@
|
|||
:database="database"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:fields="fields"
|
||||
:all-fields-in-table="fields"
|
||||
:visible-fields="allVisibleFields"
|
||||
:hidden-fields="hiddenFields"
|
||||
:rows="allRows"
|
||||
|
@ -1177,6 +1186,7 @@ export default {
|
|||
rowId: row.id,
|
||||
fieldIndex,
|
||||
})
|
||||
this.$refs.rowContext.hide()
|
||||
},
|
||||
/**
|
||||
* Called when mouse hovers over a GridViewCell component.
|
||||
|
@ -1425,7 +1435,6 @@ export default {
|
|||
}
|
||||
|
||||
this.$store.dispatch('toast/setPasting', true)
|
||||
|
||||
try {
|
||||
await this.$store.dispatch(
|
||||
this.storePrefix + 'view/grid/updateDataIntoCells',
|
||||
|
@ -1531,6 +1540,38 @@ export default {
|
|||
height
|
||||
)
|
||||
},
|
||||
/**
|
||||
* Called when the user right clicks after selecting multiple cells.
|
||||
* Shows the context menu with the appropriate options.
|
||||
*/
|
||||
getMultiSelectContextItems() {
|
||||
const selectedFields = this.$store.getters[
|
||||
this.storePrefix + 'view/grid/getSelectedFields'
|
||||
](this.fields)
|
||||
|
||||
if (selectedFields.length === 1) {
|
||||
return this.$registry
|
||||
.get('field', selectedFields[0].type)
|
||||
.getGridViewContextItemsOnCellsSelection()
|
||||
} else {
|
||||
return []
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Returns the selected field if only one field is selected, otherwise returns null.
|
||||
*/
|
||||
getSelectedField() {
|
||||
const selectedFields = this.$store.getters[
|
||||
this.storePrefix + 'view/grid/getSelectedFields'
|
||||
](this.fields)
|
||||
return selectedFields.length === 1 ? selectedFields[0] : null
|
||||
},
|
||||
/**
|
||||
* Returns the selected rows if any rows are selected, otherwise returns an empty array.
|
||||
*/
|
||||
getSelectedRows() {
|
||||
return this.$store.getters[this.storePrefix + 'view/grid/getSelectedRows']
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:field="field"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@update="$emit('refresh', $event)"
|
||||
@delete="$emit('refresh')"
|
||||
>
|
||||
|
@ -98,6 +99,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:from-field="field"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-created="$emit('field-created', $event)"
|
||||
@move-field="moveField($event)"
|
||||
></InsertFieldContext>
|
||||
|
@ -124,6 +126,7 @@
|
|||
ref="duplicateFieldModal"
|
||||
:table="table"
|
||||
:from-field="field"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-created="$emit('field-created', $event)"
|
||||
@move-field="moveField($event)"
|
||||
></DuplicateFieldModal>
|
||||
|
@ -288,6 +291,10 @@ export default {
|
|||
type: Boolean,
|
||||
required: true,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
:table="table"
|
||||
:view="view"
|
||||
:field="field"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
:filters="view.filters"
|
||||
:include-field-width-handles="includeFieldWidthHandles"
|
||||
:read-only="readOnly"
|
||||
|
@ -71,6 +72,7 @@
|
|||
ref="createFieldContext"
|
||||
:table="table"
|
||||
:view="view"
|
||||
:all-fields-in-table="allFieldsInTable"
|
||||
@field-created="$emit('field-created', $event)"
|
||||
@shown="onShownCreateFieldContext"
|
||||
></CreateFieldContext>
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
<template functional>
|
||||
<div
|
||||
v-if="!props.value || $options.methods.isGenerating(parent, props)"
|
||||
class="grid-view__cell"
|
||||
>
|
||||
<div class="grid-field-button">
|
||||
<a
|
||||
class="button button--tiny button--ghost"
|
||||
:disabled="!$options.methods.isModelAvailable(parent, props)"
|
||||
:class="{
|
||||
'button--loading': $options.methods.isGenerating(parent, props),
|
||||
}"
|
||||
>
|
||||
<i18n path="functionalGridViewFieldAI.generate" tag="span" />
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="grid-view__cell grid-field-long-text__cell">
|
||||
<div class="grid-field-long-text">{{ props.value }}</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import gridFieldAI from '@baserow/modules/database/mixins/gridFieldAI'
|
||||
|
||||
export default {
|
||||
name: 'FunctionalGridViewFieldAI',
|
||||
mixins: [gridFieldAI],
|
||||
}
|
||||
</script>
|
|
@ -0,0 +1,60 @@
|
|||
<template>
|
||||
<div
|
||||
v-if="!value || (!opened && generating)"
|
||||
ref="cell"
|
||||
class="grid-view__cell active"
|
||||
>
|
||||
<div class="grid-field-button">
|
||||
<button
|
||||
class="button button--tiny button--ghost"
|
||||
:disabled="!modelAvailable"
|
||||
:class="{ 'button--loading': generating }"
|
||||
@click="generate()"
|
||||
>
|
||||
{{ $t('gridViewFieldAI.generate') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
ref="cell"
|
||||
class="grid-view__cell grid-field-long-text__cell active"
|
||||
:class="{ editing: opened }"
|
||||
@keyup.enter="opened = true"
|
||||
>
|
||||
<div v-if="!opened" class="grid-field-long-text">{{ value }}</div>
|
||||
<template v-else>
|
||||
<div class="grid-field-long-text__textarea">
|
||||
{{ value }}
|
||||
</div>
|
||||
<div style="background-color: #fff; padding: 8px">
|
||||
<button
|
||||
class="button button--link"
|
||||
:disabled="!modelAvailable"
|
||||
:class="{ 'button--loading': generating }"
|
||||
@click.prevent.stop="generate()"
|
||||
>
|
||||
<i class="button__icon iconoir-magic-wand"></i>
|
||||
{{ $t('gridViewFieldAI.regenerate') }}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import gridField from '@baserow/modules/database/mixins/gridField'
|
||||
import gridFieldInput from '@baserow/modules/database/mixins/gridFieldInput'
|
||||
import gridFieldAI from '@baserow/modules/database/mixins/gridFieldAI'
|
||||
|
||||
export default {
|
||||
mixins: [gridField, gridFieldInput, gridFieldAI],
|
||||
methods: {
|
||||
save() {
|
||||
this.opened = false
|
||||
this.editing = false
|
||||
this.afterSave()
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
|
@ -0,0 +1,71 @@
|
|||
<template>
|
||||
<li class="context__menu-item">
|
||||
<a
|
||||
class="context__menu-item-link"
|
||||
:class="{ disabled: !modelAvailable }"
|
||||
@click.prevent.stop=";[generateAIFieldValues()]"
|
||||
>
|
||||
<i class="context__menu-item-icon iconoir-magic-wand"></i>
|
||||
{{ $t('gridView.generateCellsValues') }}
|
||||
</a>
|
||||
</li>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import FieldService from '@baserow/modules/database/services/field'
|
||||
import { notifyIf } from '@baserow/modules/core/utils/error'
|
||||
|
||||
export default {
|
||||
props: {
|
||||
field: {
|
||||
type: Object,
|
||||
required: true,
|
||||
},
|
||||
rows: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
storePrefix: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
computed: {
|
||||
modelAvailable() {
|
||||
const aIModels =
|
||||
this.$store.getters['settings/get'].generative_ai[
|
||||
this.field.ai_generative_ai_type
|
||||
] || []
|
||||
return (
|
||||
this.$registry.get('field', this.field.type).isEnabled() &&
|
||||
aIModels.includes(this.field.ai_generative_ai_model)
|
||||
)
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
async generateAIFieldValues($event) {
|
||||
if (!this.modelAvailable) {
|
||||
return
|
||||
}
|
||||
|
||||
const rowIds = this.rows.map((row) => row.id)
|
||||
const fieldId = this.field.id
|
||||
this.$store.dispatch(
|
||||
this.storePrefix + 'view/grid/setPendingFieldOperations',
|
||||
{ fieldId, rowIds, value: true }
|
||||
)
|
||||
|
||||
try {
|
||||
await FieldService(this.$client).generateAIFieldValues(fieldId, rowIds)
|
||||
} catch (error) {
|
||||
this.$store.dispatch(
|
||||
this.storePrefix + 'view/grid/setPendingFieldOperations',
|
||||
{ fieldId, rowIds, value: false }
|
||||
)
|
||||
notifyIf(error, 'field')
|
||||
}
|
||||
this.$emit('click', $event)
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
30
web-frontend/modules/database/dataProviderTypes.js
Normal file
30
web-frontend/modules/database/dataProviderTypes.js
Normal file
|
@ -0,0 +1,30 @@
|
|||
import { DataProviderType } from '@baserow/modules/core/dataProviderTypes'
|
||||
|
||||
export class FieldsDataProviderType extends DataProviderType {
|
||||
static getType() {
|
||||
return 'fields'
|
||||
}
|
||||
|
||||
get name() {
|
||||
return this.app.i18n.t('dataProviderTypes.fieldsName')
|
||||
}
|
||||
|
||||
getDataContent(applicationContext) {
|
||||
return ''
|
||||
}
|
||||
|
||||
getDataSchema(applicationContext) {
|
||||
return {
|
||||
type: 'object',
|
||||
properties: Object.fromEntries(
|
||||
(applicationContext.fields || []).map((field) => [
|
||||
`field_${field.id}`,
|
||||
{
|
||||
title: field.name,
|
||||
type: 'string',
|
||||
},
|
||||
])
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -30,6 +30,7 @@ import FieldLinkRowSubForm from '@baserow/modules/database/components/field/Fiel
|
|||
import FieldSelectOptionsSubForm from '@baserow/modules/database/components/field/FieldSelectOptionsSubForm'
|
||||
import FieldCollaboratorSubForm from '@baserow/modules/database/components/field/FieldCollaboratorSubForm'
|
||||
import FieldPasswordSubForm from '@baserow/modules/database/components/field/FieldPasswordSubForm'
|
||||
import FieldAISubForm from '@baserow/modules/database/components/field/FieldAISubForm'
|
||||
|
||||
import GridViewFieldText from '@baserow/modules/database/components/view/grid/fields/GridViewFieldText'
|
||||
import GridViewFieldLongText from '@baserow/modules/database/components/view/grid/fields/GridViewFieldLongText'
|
||||
|
@ -52,6 +53,7 @@ import GridViewFieldUUID from '@baserow/modules/database/components/view/grid/fi
|
|||
import GridViewFieldAutonumber from '@baserow/modules/database/components/view/grid/fields/GridViewFieldAutonumber'
|
||||
import GridViewFieldLastModifiedBy from '@baserow/modules/database/components/view/grid/fields/GridViewFieldLastModifiedBy'
|
||||
import GridViewFieldPassword from '@baserow/modules/database/components/view/grid/fields/GridViewFieldPassword'
|
||||
import GridViewFieldAI from '@baserow/modules/database/components/view/grid/fields/GridViewFieldAI'
|
||||
|
||||
import FunctionalGridViewFieldText from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldText'
|
||||
import FunctionalGridViewFieldDuration from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldDuration'
|
||||
|
@ -72,6 +74,7 @@ import FunctionalGridViewFieldUUID from '@baserow/modules/database/components/vi
|
|||
import FunctionalGridViewFieldAutonumber from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldAutonumber'
|
||||
import FunctionalGridViewFieldLastModifiedBy from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldLastModifiedBy'
|
||||
import FunctionalGridVIewFieldPassword from '@baserow/modules/database/components/view/grid/fields/FunctionalGridVIewFieldPassword.vue'
|
||||
import FunctionalGridViewFieldAI from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldAI'
|
||||
|
||||
import RowEditFieldText from '@baserow/modules/database/components/row/RowEditFieldText'
|
||||
import RowEditFieldLongText from '@baserow/modules/database/components/row/RowEditFieldLongText'
|
||||
|
@ -94,6 +97,7 @@ import RowEditFieldUUID from '@baserow/modules/database/components/row/RowEditFi
|
|||
import RowEditFieldAutonumber from '@baserow/modules/database/components/row/RowEditFieldAutonumber'
|
||||
import RowEditFieldLastModifiedBy from '@baserow/modules/database/components/row/RowEditFieldLastModifiedBy'
|
||||
import RowEditFieldPassword from '@baserow/modules/database/components/row/RowEditFieldPassword'
|
||||
import RowEditFieldAI from '@baserow/modules/database/components/row/RowEditFieldAI'
|
||||
|
||||
import RowCardFieldBoolean from '@baserow/modules/database/components/card/RowCardFieldBoolean'
|
||||
import RowCardFieldDate from '@baserow/modules/database/components/card/RowCardFieldDate'
|
||||
|
@ -134,6 +138,8 @@ import FormViewFieldMultipleLinkRow from '@baserow/modules/database/components/v
|
|||
import FormViewFieldMultipleSelectCheckboxes from '@baserow/modules/database/components/view/form/FormViewFieldMultipleSelectCheckboxes'
|
||||
import FormViewFieldSingleSelectRadios from '@baserow/modules/database/components/view/form/FormViewFieldSingleSelectRadios'
|
||||
|
||||
import GridViewFieldAIGenerateValuesContextItem from '@baserow/modules/database/components/view/grid/fields/GridViewFieldAIGenerateValuesContextItem'
|
||||
|
||||
import { trueValues } from '@baserow/modules/core/utils/constants'
|
||||
import {
|
||||
getDateMomentFormat,
|
||||
|
@ -197,6 +203,16 @@ export class FieldType extends Registerable {
|
|||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* This method generates the context menu options for actions that can be performed on
|
||||
* more selected cells within the same field. These options appear in the grid view
|
||||
* when the user right-clicks on multiple cells.
|
||||
* @param field The field object.
|
||||
*/
|
||||
getGridViewContextItemsOnCellsSelection(field) {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* This functional component should represent an unselect field cell related to the
|
||||
* value of this type. It will only be used in the grid view and is only for fast
|
||||
|
@ -756,6 +772,13 @@ export class FieldType extends Registerable {
|
|||
parseInputValue(field, value) {
|
||||
return value
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates whether it's possible to select the field type when creating or updating the field.
|
||||
*/
|
||||
isEnabled() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
export class TextFieldType extends FieldType {
|
||||
|
@ -4137,3 +4160,88 @@ export class PasswordFieldType extends FieldType {
|
|||
return RowHistoryFieldPassword
|
||||
}
|
||||
}
|
||||
|
||||
export class AIFieldType extends FieldType {
|
||||
static getType() {
|
||||
return 'ai'
|
||||
}
|
||||
|
||||
getIconClass() {
|
||||
return 'iconoir-magic-wand'
|
||||
}
|
||||
|
||||
getName() {
|
||||
const { i18n } = this.app
|
||||
return i18n.t('fieldType.ai')
|
||||
}
|
||||
|
||||
getGridViewFieldComponent() {
|
||||
return GridViewFieldAI
|
||||
}
|
||||
|
||||
getFunctionalGridViewFieldComponent() {
|
||||
return FunctionalGridViewFieldAI
|
||||
}
|
||||
|
||||
getRowEditFieldComponent(field) {
|
||||
return RowEditFieldAI
|
||||
}
|
||||
|
||||
getCardComponent() {
|
||||
return RowCardFieldText
|
||||
}
|
||||
|
||||
getRowHistoryEntryComponent() {
|
||||
return RowHistoryFieldText
|
||||
}
|
||||
|
||||
getFormComponent() {
|
||||
return FieldAISubForm
|
||||
}
|
||||
|
||||
getFormViewFieldComponents(field) {
|
||||
return {}
|
||||
}
|
||||
|
||||
getEmptyValue(field) {
|
||||
return null
|
||||
}
|
||||
|
||||
getSort(name, order) {
|
||||
return (a, b) => {
|
||||
const stringA = a[name] === null ? '' : '' + a[name]
|
||||
const stringB = b[name] === null ? '' : '' + b[name]
|
||||
|
||||
return collatedStringCompare(stringA, stringB, order)
|
||||
}
|
||||
}
|
||||
|
||||
getDocsDataType(field) {
|
||||
return 'string'
|
||||
}
|
||||
|
||||
getDocsDescription(field) {
|
||||
return '@TODO'
|
||||
}
|
||||
|
||||
getDocsRequestExample(field) {
|
||||
return 'string'
|
||||
}
|
||||
|
||||
getContainsFilterFunction() {
|
||||
return genericContainsFilter
|
||||
}
|
||||
|
||||
getContainsWordFilterFunction(field) {
|
||||
return genericContainsWordFilter
|
||||
}
|
||||
|
||||
getGridViewContextItemsOnCellsSelection(field) {
|
||||
return [GridViewFieldAIGenerateValuesContextItem]
|
||||
}
|
||||
|
||||
isEnabled() {
|
||||
const { store } = this.app
|
||||
return Object.keys(store.getters['settings/get'].generative_ai).length > 0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -599,11 +599,14 @@
|
|||
"deleteRow": "Delete row",
|
||||
"deleteRows": "Delete rows",
|
||||
"copyCells": "Copy cells",
|
||||
"generateCellsValues": "Generate values with AI",
|
||||
"rowCount": "No rows | 1 row | {count} rows",
|
||||
"hiddenRowsInsertedTitle": "Rows added",
|
||||
"hiddenRowsInsertedMessage": "{number} newly added rows have been added, but are not visible because of the active filters.",
|
||||
"tooManyItemsTitle": "Too many items",
|
||||
"tooManyItemsDescription": "It's not possible to update more than {limit} rows at once, so we've only updated the first."
|
||||
"tooManyItemsDescription": "It's not possible to update more than {limit} rows at once, so we've only updated the first.",
|
||||
"AIValuesGenerationErrorTitle": "AI value generation failed",
|
||||
"AIValuesGenerationErrorMessage": "Please check your API_KEY and verify the selected model."
|
||||
},
|
||||
"gridViewFieldFile": {
|
||||
"dropHere": "Drop here",
|
||||
|
@ -829,5 +832,27 @@
|
|||
"passwordSet": "The password was set",
|
||||
"passwordUpdated": "The password was updated",
|
||||
"passwordDeleted": "The password was deleted"
|
||||
},
|
||||
"dataProviderTypes": {
|
||||
"fieldsName": "Fields"
|
||||
},
|
||||
"functionalGridViewFieldAI": {
|
||||
"generate": "Generate"
|
||||
},
|
||||
"gridViewFieldAI": {
|
||||
"generate": "Generate",
|
||||
"regenerate": "Re-Generate"
|
||||
},
|
||||
"fieldAISubForm": {
|
||||
"AIType": "AI Type",
|
||||
"AIModel": "AI Model",
|
||||
"prompt": "Prompt"
|
||||
},
|
||||
"rowEditFieldAI": {
|
||||
"generate": "Generate",
|
||||
"createRowBefore": "The AI value can be generated after the row has been created."
|
||||
},
|
||||
"rowCardFieldAI": {
|
||||
"generate": "Generate"
|
||||
}
|
||||
}
|
||||
|
|
40
web-frontend/modules/database/mixins/fieldAI.js
Normal file
40
web-frontend/modules/database/mixins/fieldAI.js
Normal file
|
@ -0,0 +1,40 @@
|
|||
import FieldService from '@baserow/modules/database/services/field'
|
||||
import { notifyIf } from '@baserow/modules/core/utils/error'
|
||||
|
||||
export default {
|
||||
data() {
|
||||
return {
|
||||
generating: false,
|
||||
}
|
||||
},
|
||||
computed: {
|
||||
modelAvailable() {
|
||||
const aIModels =
|
||||
this.$store.getters['settings/get'].generative_ai[
|
||||
this.field.ai_generative_ai_type
|
||||
] || []
|
||||
return (
|
||||
this.$registry.get('field', this.field.type).isEnabled() &&
|
||||
aIModels.includes(this.field.ai_generative_ai_model)
|
||||
)
|
||||
},
|
||||
},
|
||||
watch: {
|
||||
value() {
|
||||
this.generating = false
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
async generate() {
|
||||
this.generating = true
|
||||
try {
|
||||
await FieldService(this.$client).generateAIFieldValues(this.field.id, [
|
||||
this.$parent.row.id,
|
||||
])
|
||||
} catch (error) {
|
||||
notifyIf(error, 'field')
|
||||
this.generating = false
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
|
@ -15,6 +15,9 @@ export default {
|
|||
},
|
||||
primary: {
|
||||
type: Boolean,
|
||||
},
|
||||
allFieldsInTable: {
|
||||
type: Array,
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
|
|
50
web-frontend/modules/database/mixins/gridFieldAI.js
Normal file
50
web-frontend/modules/database/mixins/gridFieldAI.js
Normal file
|
@ -0,0 +1,50 @@
|
|||
import FieldService from '@baserow/modules/database/services/field'
|
||||
import { notifyIf } from '@baserow/modules/core/utils/error'
|
||||
|
||||
export default {
|
||||
computed: {
|
||||
// Indicates if the cell is currently being generated together with other cells in
|
||||
// bulk via the apposite grid-view menu.
|
||||
generating() {
|
||||
return this.isGenerating(this.$parent, this.$props)
|
||||
},
|
||||
modelAvailable() {
|
||||
return this.isModelAvailable(this.$parent, this.$props)
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
isGenerating(parent, props) {
|
||||
return parent.row._.pendingFieldOps?.find(
|
||||
(fieldName) => fieldName === `field_${props.field.id}`
|
||||
)
|
||||
},
|
||||
isModelAvailable(parent, props) {
|
||||
const aIModels =
|
||||
parent.$store.getters['settings/get'].generative_ai[
|
||||
props.field.ai_generative_ai_type
|
||||
] || []
|
||||
return (
|
||||
parent.$registry.get('field', props.field.type).isEnabled() &&
|
||||
aIModels.includes(props.field.ai_generative_ai_model)
|
||||
)
|
||||
},
|
||||
async generate() {
|
||||
const rowId = this.$parent.row.id
|
||||
this.$store.dispatch(
|
||||
this.storePrefix + 'view/grid/setPendingFieldOperations',
|
||||
{ fieldId: this.field.id, rowIds: [rowId], value: true }
|
||||
)
|
||||
try {
|
||||
await FieldService(this.$client).generateAIFieldValues(this.field.id, [
|
||||
rowId,
|
||||
])
|
||||
} catch (error) {
|
||||
notifyIf(error, 'field')
|
||||
this.$store.dispatch(
|
||||
this.storePrefix + 'view/grid/setPendingFieldOperations',
|
||||
{ fieldId: this.field.id, rowIds: [rowId], value: false }
|
||||
)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
|
@ -32,6 +32,11 @@ export default {
|
|||
required: false,
|
||||
default: true,
|
||||
},
|
||||
rowIsCreated: {
|
||||
type: Boolean,
|
||||
required: false,
|
||||
default: () => true,
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
/**
|
||||
|
|
|
@ -32,6 +32,7 @@ import {
|
|||
UUIDFieldType,
|
||||
AutonumberFieldType,
|
||||
PasswordFieldType,
|
||||
AIFieldType,
|
||||
} from '@baserow/modules/database/fieldTypes'
|
||||
import {
|
||||
EqualViewFilterType,
|
||||
|
@ -257,6 +258,7 @@ import {
|
|||
FormSubmittedNotificationType,
|
||||
} from '@baserow/modules/database/notificationTypes'
|
||||
import { HistoryRowModalSidebarType } from '@baserow/modules/database/rowModalSidebarTypes'
|
||||
import { FieldsDataProviderType } from '@baserow/modules/database/dataProviderTypes'
|
||||
|
||||
import en from '@baserow/modules/database/locales/en.json'
|
||||
import fr from '@baserow/modules/database/locales/fr.json'
|
||||
|
@ -465,6 +467,7 @@ export default (context) => {
|
|||
app.$registry.register('field', new UUIDFieldType(context))
|
||||
app.$registry.register('field', new AutonumberFieldType(context))
|
||||
app.$registry.register('field', new PasswordFieldType(context))
|
||||
app.$registry.register('field', new AIFieldType(context))
|
||||
|
||||
app.$registry.register('importer', new CSVImporterType(context))
|
||||
app.$registry.register('importer', new PasteImporterType(context))
|
||||
|
@ -709,6 +712,11 @@ export default (context) => {
|
|||
|
||||
app.$registry.register('formViewMode', new FormViewFormModeType(context))
|
||||
|
||||
app.$registry.register(
|
||||
'databaseDataProvider',
|
||||
new FieldsDataProviderType(context)
|
||||
)
|
||||
|
||||
// notifications
|
||||
app.$registry.register(
|
||||
'notification',
|
||||
|
|
|
@ -214,6 +214,24 @@ export const registerRealtimeEvents = (realtime) => {
|
|||
}
|
||||
})
|
||||
|
||||
realtime.registerEvent(
|
||||
'rows_ai_values_generation_error',
|
||||
async (context, data) => {
|
||||
const { app } = context
|
||||
|
||||
for (const viewType of Object.values(app.$registry.getAll('view'))) {
|
||||
await viewType.AIValuesGenerationError(
|
||||
context,
|
||||
data.table_id,
|
||||
data.field_id,
|
||||
data.row_ids,
|
||||
data.error,
|
||||
'page/'
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
realtime.registerEvent('rows_deleted', (context, data) => {
|
||||
const { app, store } = context
|
||||
for (const viewType of Object.values(app.$registry.getAll('view'))) {
|
||||
|
|
|
@ -46,5 +46,11 @@ export default (client) => {
|
|||
config
|
||||
)
|
||||
},
|
||||
generateAIFieldValues(fieldId, rowIds) {
|
||||
return client.post(
|
||||
`/database/fields/${fieldId}/generate-ai-field-values/`,
|
||||
{ row_ids: rowIds }
|
||||
)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,6 +49,10 @@ export function populateRow(row, metadata = {}) {
|
|||
// between cells.
|
||||
selected: false,
|
||||
selectedFieldId: -1,
|
||||
// Contains the specific field ids that are in a loading state. This is for
|
||||
// example used for fields that use a background worker to compute the value
|
||||
// like the AI field.
|
||||
pendingFieldOps: [],
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
@ -412,6 +416,17 @@ export const mutations = {
|
|||
if (metadata) {
|
||||
existingRowState._.metadata = metadata
|
||||
}
|
||||
|
||||
// Remove every pending AI field if a value is provided for it.
|
||||
if (existingRowState._?.pendingFieldOps?.length > 0) {
|
||||
const newFieldKeys = new Set(
|
||||
Object.keys(values).filter((key) => values[key])
|
||||
)
|
||||
existingRowState._.pendingFieldOps =
|
||||
existingRowState._.pendingFieldOps.filter(
|
||||
(key) => !newFieldKeys.has(key)
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
UPDATE_ROW_VALUES(state, { row, values }) {
|
||||
|
@ -586,6 +601,20 @@ export const mutations = {
|
|||
}
|
||||
})
|
||||
},
|
||||
SET_PENDING_FIELD_OPERATIONS(state, { fieldId, rowIds, value }) {
|
||||
const key = `field_${fieldId}`
|
||||
state.rows.forEach((row) => {
|
||||
if (rowIds.includes(row.id)) {
|
||||
if (value) {
|
||||
row._.pendingFieldOps.push(key)
|
||||
} else {
|
||||
row._.pendingFieldOps = row._.pendingFieldOps.filter(
|
||||
(fieldName) => fieldName !== key
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Contains the info needed for the delayed scroll top action.
|
||||
|
@ -2260,6 +2289,7 @@ export const actions = {
|
|||
jsonData,
|
||||
rowIndex,
|
||||
fieldIndex,
|
||||
selectUpdatedCells = true,
|
||||
}
|
||||
) {
|
||||
const copiedRowsCount = textData.length
|
||||
|
@ -2324,7 +2354,7 @@ export const actions = {
|
|||
rowTailIndex = rowTailIndex + newRowsCount
|
||||
}
|
||||
|
||||
if (!isSingleCellCopied) {
|
||||
if (!isSingleCellCopied && selectUpdatedCells) {
|
||||
// Expand the selection of the multiple select to the cells that we're going to
|
||||
// paste in, so the user can see which values have been updated. This is because
|
||||
// it could be that there are more or less values in the clipboard compared to
|
||||
|
@ -2929,6 +2959,26 @@ export const actions = {
|
|||
fieldIndex: minFieldIndex,
|
||||
})
|
||||
},
|
||||
/**
|
||||
* Add the fieldId to the list of pending field operations for the given rowIds.
|
||||
* This is used to show a loading spinner when a field is being updated. For example,
|
||||
* the AI field type uses this to show a spinner when the AI values are being
|
||||
* generated in a background task.
|
||||
*/
|
||||
setPendingFieldOperations({ commit }, { fieldId, rowIds, value = true }) {
|
||||
commit('SET_PENDING_FIELD_OPERATIONS', { fieldId, rowIds, value })
|
||||
},
|
||||
AIValuesGenerationError({ commit, dispatch }, { fieldId, rowIds }) {
|
||||
commit('SET_PENDING_FIELD_OPERATIONS', { fieldId, rowIds, value: false })
|
||||
dispatch(
|
||||
'toast/error',
|
||||
{
|
||||
title: this.$i18n.t('gridView.AIValuesGenerationErrorTitle'),
|
||||
message: this.$i18n.t('gridView.AIValuesGenerationErrorMessage'),
|
||||
},
|
||||
{ root: true }
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
export const getters = {
|
||||
|
@ -3144,6 +3194,23 @@ export const getters = {
|
|||
)
|
||||
}
|
||||
},
|
||||
getSelectedFields: (state, getters) => (fields) => {
|
||||
const [minField, maxField] = getters.getMultiSelectFieldIndexSorted
|
||||
const selectedFields = []
|
||||
|
||||
const fieldMap = fields.reduce((acc, field) => {
|
||||
acc[field.id] = field
|
||||
return acc
|
||||
}, {})
|
||||
|
||||
for (let i = minField; i <= maxField; i++) {
|
||||
const fieldId = getters.getFieldIdByIndex(i, fields)
|
||||
if (fieldId !== -1) {
|
||||
selectedFields.push(fieldMap[fieldId])
|
||||
}
|
||||
}
|
||||
return selectedFields
|
||||
},
|
||||
getAllFieldAggregationData(state) {
|
||||
return state.fieldAggregationData
|
||||
},
|
||||
|
|
|
@ -259,6 +259,12 @@ export class ViewType extends Registerable {
|
|||
*/
|
||||
rowUpdated(context, tableId, fields, row, values, metadata, storePrefix) {}
|
||||
|
||||
/**
|
||||
* Event that is called when something went wrong while generating AI values
|
||||
* for a field. This can be used to show an error message to the user.
|
||||
*/
|
||||
AIValuesGenerationError(context, tableId, fieldId, rowIds, error) {}
|
||||
|
||||
/**
|
||||
* Event that is called when a row is deleted from an outside source, so for example
|
||||
* via a real time event by another user. It can be used to check if data in an store
|
||||
|
@ -630,6 +636,27 @@ export class GridViewType extends ViewType {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
AIValuesGenerationError(
|
||||
context,
|
||||
tableId,
|
||||
fieldId,
|
||||
rowIds,
|
||||
error,
|
||||
storePrefix = ''
|
||||
) {
|
||||
if (this.isCurrentView(context.store, tableId)) {
|
||||
context.store.dispatch(
|
||||
storePrefix + 'view/grid/AIValuesGenerationError',
|
||||
{
|
||||
fieldId,
|
||||
rowIds,
|
||||
error,
|
||||
},
|
||||
{ root: true }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -333,6 +333,15 @@ const mockedFields = {
|
|||
type: 'password',
|
||||
testingRowData: [null, true, 'test'],
|
||||
},
|
||||
ai: {
|
||||
id: 26,
|
||||
name: 'ai',
|
||||
order: 26,
|
||||
primary: false,
|
||||
table_id: 42,
|
||||
type: 'ai',
|
||||
testingRowData: [null, 'Generated: hello!'],
|
||||
},
|
||||
}
|
||||
|
||||
const valuesToCall = [null, undefined]
|
||||
|
|
Loading…
Add table
Reference in a new issue