1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-17 18:32:35 +00:00
This commit is contained in:
Bram Wiepjes 2024-04-02 19:50:44 +00:00
parent 7021acdb64
commit 1298aa7eec
98 changed files with 2393 additions and 108 deletions
backend
changelog/entries/unreleased/feature
docker-compose.yml
premium
backend/tests/baserow_premium_tests/export
web-frontend/modules/baserow_premium/components/views
web-frontend

View file

@ -21,7 +21,7 @@ format:
fix: sort format
sort:
isort --skip generated src tests ../premium/backend ../enterprise/backend || exit;
isort --skip generated --overwrite-in-place src tests ../premium/backend ../enterprise/backend || exit;
test:
pytest tests ../premium/backend/tests ../enterprise/backend/tests || exit;

View file

@ -37,6 +37,7 @@ markers =
field_multiple_collaborators: All tests related to multiple collaborator field
field_last_modified_by: All tests related to last modified by field
field_autonumber: All tests related to autonumber field
field_ai: All tests related to AI field
view_ownership: All tests related to view ownership type
view_calendar: All tests related to the calendar view
api_rows: All tests to manipulate rows via HTTP API

View file

@ -66,3 +66,6 @@ https://github.com/fellowapp/prosemirror-py/archive/refs/tags/v0.3.5.zip
rich==13.7.0
tzdata==2023.3
sentry-sdk==1.39.1
openai==1.9.0
typing_extensions==4.7.1
ollama==0.1.5

View file

@ -8,10 +8,15 @@ advocate==1.0.0
# via -r base.in
amqp==5.1.1
# via kombu
annotated-types==0.6.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via -r base.in
anyio==3.6.1
# via watchfiles
# via
# httpx
# openai
# watchfiles
asgiref==3.6.0
# via
# -r base.in
@ -66,6 +71,8 @@ celery-singleton==0.3.1
# via -r base.in
certifi==2023.7.22
# via
# httpcore
# httpx
# requests
# sentry-sdk
cffi==1.15.1
@ -110,6 +117,8 @@ deprecated==1.2.13
# via
# opentelemetry-api
# opentelemetry-exporter-otlp-proto-http
distro==1.9.0
# via openai
dj-database-url==1.3.0
# via -r base.in
django==4.1.13
@ -183,9 +192,17 @@ googleapis-common-protos==1.58.0
gunicorn==20.1.0
# via -r base.in
h11==0.14.0
# via uvicorn
# via
# httpcore
# uvicorn
httpcore==1.0.2
# via httpx
httptools==0.5.0
# via uvicorn
httpx==0.25.2
# via
# ollama
# openai
hyperlink==21.0.0
# via
# autobahn
@ -193,6 +210,7 @@ hyperlink==21.0.0
idna==3.4
# via
# anyio
# httpx
# hyperlink
# requests
# twisted
@ -230,6 +248,10 @@ netifaces==0.11.0
# via advocate
oauthlib==3.2.2
# via requests-oauthlib
ollama==0.1.5
# via -r base.in
openai==1.9.0
# via -r base.in
opentelemetry-api==1.21.0
# via
# -r base.in
@ -360,6 +382,10 @@ pyasn1-modules==0.2.8
# service-identity
pycparser==2.21
# via cffi
pydantic==2.5.3
# via openai
pydantic-core==2.14.6
# via pydantic
pygments==2.17.2
# via rich
pyjwt==2.5.0
@ -443,26 +469,35 @@ six==1.16.0
# python-dateutil
# service-identity
sniffio==1.3.0
# via anyio
# via
# anyio
# httpx
# openai
sqlparse==0.4.4
# via django
tenacity==8.1.0
# via celery-redbeat
tqdm==4.65.0
# via -r base.in
# via
# -r base.in
# openai
twisted[tls]==23.10.0
# via
# -r base.in
# daphne
txaio==23.1.1
# via autobahn
typing-extensions==4.4.0
typing-extensions==4.7.1
# via
# -r base.in
# azure-core
# azure-storage-blob
# dj-database-url
# openai
# opentelemetry-sdk
# prosemirror
# pydantic
# pydantic-core
# twisted
# uvicorn
tzdata==2023.3

View file

@ -345,7 +345,7 @@ types-requests==2.28.11.2
# via djangorestframework-stubs
types-urllib3==1.26.25
# via types-requests
typing-extensions==4.4.0
typing-extensions==4.7.1
# via
# -c base.txt
# black

View file

@ -0,0 +1,12 @@
from rest_framework.status import HTTP_400_BAD_REQUEST
ERROR_GENERATIVE_AI_DOES_NOT_EXIST = (
"ERROR_GENERATIVE_AI_DOES_NOT_EXIST",
HTTP_400_BAD_REQUEST,
"The requested generative AI does not exist.",
)
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE = (
"ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE",
HTTP_400_BAD_REQUEST,
"The requested model does not belong to the provided type.",
)

View file

@ -1,6 +1,7 @@
from rest_framework import serializers
from baserow.api.user_files.serializers import UserFileField
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
from baserow.core.models import Settings
@ -23,6 +24,7 @@ class SettingsSerializer(serializers.ModelSerializer):
required=False,
help_text="Co-branding logo that's placed next to the Baserow logo (176x29).",
)
generative_ai = serializers.SerializerMethodField()
class Meta:
model = Settings
@ -38,6 +40,7 @@ class SettingsSerializer(serializers.ModelSerializer):
"track_workspace_usage",
"show_baserow_help_request",
"co_branding_logo",
"generative_ai",
)
extra_kwargs = {
"allow_new_signups": {"required": False},
@ -49,6 +52,9 @@ class SettingsSerializer(serializers.ModelSerializer):
"show_baserow_help_request": {"required": False},
}
def get_generative_ai(self, object):
return generative_ai_model_type_registry.get_models_per_type()
class InstanceIdSerializer(serializers.ModelSerializer):
class Meta:

View file

@ -1206,3 +1206,15 @@ if SENTRY_DSN:
send_default_pii=False,
environment=os.getenv("SENTRY_ENVIRONMENT", ""),
)
BASEROW_OPENAI_API_KEY = os.getenv("BASEROW_OPENAI_API_KEY", None)
BASEROW_OPENAI_ORGANIZATION = os.getenv("BASEROW_OPENAI_ORGANIZATION", "") or None
BASEROW_OPENAI_MODELS = os.getenv("BASEROW_OPENAI_MODELS", "")
BASEROW_OPENAI_MODELS = (
BASEROW_OPENAI_MODELS.split(",") if BASEROW_OPENAI_MODELS else []
)
BASEROW_OLLAMA_HOST = os.getenv("BASEROW_OLLAMA_HOST", None)
BASEROW_OLLAMA_MODELS = os.getenv("BASEROW_OLLAMA_MODELS", "")
BASEROW_OLLAMA_MODELS = (
BASEROW_OLLAMA_MODELS.split(",") if BASEROW_OLLAMA_MODELS else []
)

View file

@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Optional
from django.conf import settings
from django.contrib.auth.hashers import make_password
from django.core.exceptions import ValidationError
from django.utils.functional import lazy
@ -319,3 +320,14 @@ class PasswordSerializer(serializers.CharField):
return None
return make_password(data)
class GenerateAIFieldValueViewSerializer(serializers.Serializer):
row_ids = serializers.ListField(
child=serializers.IntegerField(),
max_length=settings.BATCH_ROWS_SIZE_LIMIT,
help_text="The ids of the rows that the values should be generated for.",
)
def to_internal_value(self, data):
return super().to_internal_value(data)

View file

@ -4,6 +4,7 @@ from baserow.contrib.database.fields.registries import field_type_registry
from .views import (
AsyncDuplicateFieldView,
AsyncGenerateAIFieldValuesView,
FieldsView,
FieldView,
UniqueRowValueFieldView,
@ -24,4 +25,9 @@ urlpatterns = field_type_registry.api_urls + [
AsyncDuplicateFieldView.as_view(),
name="async_duplicate",
),
re_path(
r"(?P<field_id>[0-9]+)/generate-ai-field-values/$",
AsyncGenerateAIFieldValuesView.as_view(),
name="async_generate_ai_field_values",
),
]

View file

@ -19,6 +19,10 @@ from baserow.api.decorators import (
validate_query_parameters,
)
from baserow.api.errors import ERROR_USER_NOT_IN_GROUP
from baserow.api.generative_ai.errors import (
ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
)
from baserow.api.jobs.errors import ERROR_MAX_JOB_COUNT_EXCEEDED
from baserow.api.jobs.serializers import JobSerializer
from baserow.api.schemas import (
@ -46,6 +50,7 @@ from baserow.contrib.database.api.fields.errors import (
ERROR_MAX_FIELD_COUNT_EXCEEDED,
ERROR_RESERVED_BASEROW_FIELD_NAME,
)
from baserow.contrib.database.api.rows.errors import ERROR_ROW_DOES_NOT_EXIST
from baserow.contrib.database.api.tables.errors import (
ERROR_FAILED_TO_LOCK_TABLE_DUE_TO_CONFLICT,
ERROR_TABLE_DOES_NOT_EXIST,
@ -75,13 +80,16 @@ from baserow.contrib.database.fields.exceptions import (
)
from baserow.contrib.database.fields.handler import FieldHandler
from baserow.contrib.database.fields.job_types import DuplicateFieldJobType
from baserow.contrib.database.fields.models import Field
from baserow.contrib.database.fields.models import AIField, Field
from baserow.contrib.database.fields.operations import (
CreateFieldOperationType,
ListFieldsOperationType,
ReadFieldOperationType,
)
from baserow.contrib.database.fields.registries import field_type_registry
from baserow.contrib.database.fields.tasks import generate_ai_values_for_rows
from baserow.contrib.database.rows.exceptions import RowDoesNotExist
from baserow.contrib.database.rows.handler import RowHandler
from baserow.contrib.database.table.exceptions import (
FailedToLockTableDueToConflict,
TableDoesNotExist,
@ -92,6 +100,11 @@ from baserow.contrib.database.tokens.handler import TokenHandler
from baserow.core.action.registries import action_type_registry
from baserow.core.db import specific_iterator
from baserow.core.exceptions import UserNotInWorkspace
from baserow.core.generative_ai.exceptions import (
GenerativeAITypeDoesNotExist,
ModelDoesNotBelongToType,
)
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
from baserow.core.handler import CoreHandler
from baserow.core.jobs.exceptions import MaxJobCountExceeded
from baserow.core.jobs.handler import JobHandler
@ -103,6 +116,7 @@ from .serializers import (
DuplicateFieldParamsSerializer,
FieldSerializer,
FieldSerializerWithRelatedFields,
GenerateAIFieldValueViewSerializer,
RelatedFieldsSerializer,
UniqueRowValueParamsSerializer,
UniqueRowValuesSerializer,
@ -600,3 +614,87 @@ class AsyncDuplicateFieldView(APIView):
serializer = job_type_registry.get_serializer(job, JobSerializer)
return Response(serializer.data, status=status.HTTP_202_ACCEPTED)
class AsyncGenerateAIFieldValuesView(APIView):
permission_classes = (IsAuthenticated,)
@extend_schema(
parameters=[
OpenApiParameter(
name="field_id",
location=OpenApiParameter.PATH,
type=OpenApiTypes.INT,
description="The field to generate the value for.",
),
CLIENT_SESSION_ID_SCHEMA_PARAMETER,
CLIENT_UNDO_REDO_ACTION_GROUP_ID_SCHEMA_PARAMETER,
],
tags=["Database table fields"],
operation_id="generate_table_ai_field_value",
description=(
"Endpoint that's used by the AI field to start an sync task that "
"will update the cell value of the provided row IDs based on the "
"dynamically constructed prompt configured in the field settings."
),
request=None,
responses={
200: str,
400: get_error_schema(
[
"ERROR_GENERATIVE_AI_DOES_NOT_EXIST",
"ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE",
]
),
404: get_error_schema(
[
"ERROR_FIELD_DOES_NOT_EXIST",
"ERROR_ROW_DOES_NOT_EXIST",
]
),
},
)
@transaction.atomic
@map_exceptions(
{
FieldDoesNotExist: ERROR_FIELD_DOES_NOT_EXIST,
RowDoesNotExist: ERROR_ROW_DOES_NOT_EXIST,
UserNotInWorkspace: ERROR_USER_NOT_IN_GROUP,
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
ModelDoesNotBelongToType: ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
}
)
@validate_body(GenerateAIFieldValueViewSerializer, return_validated=True)
def post(self, request: Request, field_id: int, data) -> Response:
ai_field = FieldHandler().get_field(
field_id,
base_queryset=AIField.objects.all().select_related(
"table__database__workspace"
),
)
CoreHandler().check_permissions(
request.user,
ListFieldsOperationType.type,
workspace=ai_field.table.database.workspace,
context=ai_field.table,
)
model = ai_field.table.get_model()
req_row_ids = data["row_ids"]
rows = RowHandler().get_rows(model, req_row_ids)
if len(rows) != len(req_row_ids):
found_rows_ids = [row.id for row in rows]
raise RowDoesNotExist(sorted(list(set(req_row_ids) - set(found_rows_ids))))
generative_ai_model_type = generative_ai_model_type_registry.get(
ai_field.ai_generative_ai_type
)
ai_models = generative_ai_model_type.get_enabled_models()
if ai_field.ai_generative_ai_model not in ai_models:
raise ModelDoesNotBelongToType(model_name=ai_field.ai_generative_ai_model)
generate_ai_values_for_rows.delay(request.user.id, ai_field.id, req_row_ids)
return Response(status=status.HTTP_202_ACCEPTED)

View file

@ -162,6 +162,7 @@ class DatabaseConfig(AppConfig):
plugin_registry.register(DatabasePlugin())
from .fields.field_types import (
AIFieldType,
AutonumberFieldType,
BooleanFieldType,
CountFieldType,
@ -216,8 +217,10 @@ class DatabaseConfig(AppConfig):
field_type_registry.register(UUIDFieldType())
field_type_registry.register(AutonumberFieldType())
field_type_registry.register(PasswordFieldType())
field_type_registry.register(AIFieldType())
from .fields.field_converters import (
AIFieldConverter,
AutonumberFieldConverter,
FileFieldConverter,
FormulaFieldConverter,
@ -244,6 +247,7 @@ class DatabaseConfig(AppConfig):
field_converter_registry.register(FormulaFieldConverter())
field_converter_registry.register(AutonumberFieldConverter())
field_converter_registry.register(PasswordFieldConverter())
field_converter_registry.register(AIFieldConverter())
from .fields.actions import (
CreateFieldActionType,
@ -726,6 +730,16 @@ class DatabaseConfig(AppConfig):
subject_type_registry.register(TokenSubjectType())
from baserow.contrib.database.data_providers.registries import (
database_data_provider_type_registry,
)
from .rows.data_providers import HumanReadableFieldsDataProviderType
database_data_provider_type_registry.register(
HumanReadableFieldsDataProviderType()
)
# notification_types
from baserow.contrib.database.fields.notification_types import (
CollaboratorAddedToRowNotificationType,

View file

@ -0,0 +1,3 @@
from baserow.core.formula.registries import DataProviderTypeRegistry
database_data_provider_type_registry = DataProviderTypeRegistry()

View file

@ -10,15 +10,18 @@ from baserow.contrib.database.db.schema import (
)
from .models import (
AIField,
AutonumberField,
FileField,
FormulaField,
LinkRowField,
LongTextField,
MultipleCollaboratorsField,
MultipleSelectField,
PasswordField,
SelectOption,
SingleSelectField,
TextField,
)
from .registries import FieldConverter, field_type_registry
@ -74,6 +77,16 @@ class PasswordFieldConverter(RecreateFieldConverter):
return to_password or from_password
class AIFieldConverter(RecreateFieldConverter):
type = "ai"
def is_applicable(self, from_model, from_field, to_field):
from_ai = isinstance(from_field, AIField)
to_ai = isinstance(to_field, AIField)
to_text_fields = isinstance(to_field, (TextField, LongTextField))
return from_ai and not (to_text_fields or to_ai) or not from_ai and to_ai
class LinkRowFieldConverter(RecreateFieldConverter):
type = "link_row"

View file

@ -235,6 +235,14 @@ def construct_all_possible_field_kwargs(
"uuid": [{"name": "uuid"}],
"autonumber": [{"name": "autonumber"}],
"password": [{"name": "password"}],
"ai": [
{
"name": "ai",
"ai_generative_ai_type": "test_generative_ai",
"ai_generative_ai_model": "test_1",
"ai_prompt": "Who are you?",
}
],
}
# If you have added a new field please add an entry into the dict above with any
# test worthy combinations of kwargs

View file

@ -40,6 +40,10 @@ from dateutil.parser import ParserError
from loguru import logger
from rest_framework import serializers
from baserow.api.generative_ai.errors import (
ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
)
from baserow.contrib.database.api.fields.errors import (
ERROR_DATE_FORCE_TIMEZONE_OFFSET_ERROR,
ERROR_INCOMPATIBLE_PRIMARY_FIELD_TYPE,
@ -96,6 +100,12 @@ from baserow.core.expressions import DateTrunc
from baserow.core.fields import SyncedDateTimeField
from baserow.core.formula import BaserowFormulaException
from baserow.core.formula.parser.exceptions import FormulaFunctionTypeDoesNotExist
from baserow.core.formula.serializers import FormulaSerializerField
from baserow.core.generative_ai.exceptions import (
GenerativeAITypeDoesNotExist,
ModelDoesNotBelongToType,
)
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
from baserow.core.handler import CoreHandler
from baserow.core.models import UserFile, WorkspaceUser
from baserow.core.registries import ImportExportConfig
@ -155,6 +165,7 @@ from .fields import (
from .handler import FieldHandler
from .models import (
AbstractSelectOption,
AIField,
AutonumberField,
BooleanField,
CountField,
@ -5898,3 +5909,102 @@ class PasswordFieldType(FieldType):
# We don't want to expose the hash of the password, so we just show `True` or
# `False` as string depending on whether the value is set.
return bool(value)
class AIFieldType(CollationSortMixin, FieldType):
"""
The AI field can automatically query a generative AI model based on the provided
prompt. It's possible to reference other fields to generate a unique output.
"""
type = "ai"
model_class = AIField
can_be_in_form_view = False
keep_data_on_duplication = True
allowed_fields = ["ai_generative_ai_type", "ai_generative_ai_model", "ai_prompt"]
serializer_field_names = [
"ai_generative_ai_type",
"ai_generative_ai_model",
"ai_prompt",
]
serializer_field_overrides = {
"ai_prompt": FormulaSerializerField(
help_text="The prompt that must run for each row. Must be an formula.",
required=False,
allow_blank=True,
default="",
),
}
api_exceptions_map = {
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
ModelDoesNotBelongToType: ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
}
can_get_unique_values = False
def get_serializer_field(self, instance, **kwargs):
required = kwargs.get("required", False)
return serializers.CharField(
**{
"required": required,
"allow_null": not required,
"allow_blank": not required,
**kwargs,
}
)
def get_model_field(self, instance, **kwargs):
return models.TextField(null=True, **kwargs)
def get_serializer_help_text(self, instance):
return (
"Holds a text value that is generated by a generative UI model using a "
"dynamic prompt."
)
def random_value(self, instance, fake, cache):
return fake.name()
def to_baserow_formula_type(self, field) -> BaserowFormulaType:
return BaserowFormulaTextType(nullable=True)
def from_baserow_formula_type(
self, formula_type: BaserowFormulaTextType
) -> TextField:
return TextField()
def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> any:
value = getattr(row, field.db_column)
return collate_expression(Value(value))
def contains_query(self, *args):
return contains_filter(*args)
def contains_word_query(self, *args):
return contains_word_filter(*args)
def _validate_field_kwargs(self, ai_type, model_type):
ai_type = generative_ai_model_type_registry.get(ai_type)
models = ai_type.get_enabled_models()
if model_type not in models:
raise ModelDoesNotBelongToType(model_name=model_type)
def before_create(
self, table, primary, allowed_field_values, order, user, field_kwargs
):
ai_type = field_kwargs.get("ai_generative_ai_type", None)
model_type = field_kwargs.get("ai_generative_ai_model", None)
self._validate_field_kwargs(ai_type, model_type)
def before_update(self, from_field, to_field_values, user, field_kwargs):
update_field = None
if isinstance(from_field, AIField):
update_field = from_field
ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr(
update_field, "ai_generative_ai_type", None
)
model_type = field_kwargs.get("ai_generative_ai_model", None) or getattr(
update_field, "ai_generative_ai_model", None
)
self._validate_field_kwargs(ai_type, model_type)

View file

@ -26,6 +26,7 @@ from baserow.contrib.database.table.constants import (
MULTIPLE_SELECT_THROUGH_TABLE_PREFIX,
get_tsv_vector_field_name,
)
from baserow.core.formula.field import FormulaField as ModelFormulaField
from baserow.core.jobs.mixins import (
JobWithUndoRedoIds,
JobWithUserIpAddress,
@ -740,6 +741,12 @@ class PasswordField(Field):
pass
class AIField(Field):
ai_generative_ai_type = models.CharField(max_length=32, null=True)
ai_generative_ai_model = models.CharField(max_length=32, null=True)
ai_prompt = ModelFormulaField(default="")
class DuplicateFieldJob(
JobWithUserIpAddress, JobWithWebsocketId, JobWithUndoRedoIds, Job
):

View file

@ -177,7 +177,6 @@ def notify_users_added_to_row_when_rows_updated(
model,
before_return,
updated_field_ids,
before_rows_values,
m2m_change_tracker=None,
**kwargs
):

View file

@ -10,10 +10,25 @@ from loguru import logger
from opentelemetry import trace
from baserow.config.celery import app
from baserow.contrib.database.fields.handler import FieldHandler
from baserow.contrib.database.fields.models import AIField
from baserow.contrib.database.fields.operations import ListFieldsOperationType
from baserow.contrib.database.fields.registries import field_type_registry
from baserow.contrib.database.rows.exceptions import RowDoesNotExist
from baserow.contrib.database.rows.handler import RowHandler
from baserow.contrib.database.rows.runtime_formula_contexts import (
HumanReadableRowContext,
)
from baserow.contrib.database.rows.signals import rows_ai_values_generation_error
from baserow.contrib.database.search.handler import SearchHandler
from baserow.core.formula import resolve_formula
from baserow.core.formula.registries import formula_runtime_function_registry
from baserow.core.generative_ai.exceptions import ModelDoesNotBelongToType
from baserow.core.generative_ai.registries import generative_ai_model_type_registry
from baserow.core.handler import CoreHandler
from baserow.core.models import Workspace
from baserow.core.telemetry.utils import add_baserow_trace_attrs, baserow_trace
from baserow.core.user.handler import User
tracer = trace.get_tracer(__name__)
@ -105,6 +120,74 @@ def _run_periodic_field_type_update_per_workspace(
SearchHandler().entire_field_values_changed_or_created(fields[0].table, fields)
@app.task(bind=True, queue="export")
def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list[int]):
user = User.objects.get(pk=user_id)
ai_field = FieldHandler().get_field(
field_id,
base_queryset=AIField.objects.all().select_related(
"table__database__workspace"
),
)
table = ai_field.table
CoreHandler().check_permissions(
user,
ListFieldsOperationType.type,
workspace=table.database.workspace,
context=table,
)
model = ai_field.table.get_model()
req_row_ids = row_ids
rows = RowHandler().get_rows(model, req_row_ids)
if len(rows) != len(req_row_ids):
found_rows_ids = [row.id for row in rows]
raise RowDoesNotExist(sorted(list(set(req_row_ids) - set(found_rows_ids))))
generative_ai_model_type = generative_ai_model_type_registry.get(
ai_field.ai_generative_ai_type
)
ai_models = generative_ai_model_type.get_enabled_models()
if ai_field.ai_generative_ai_model not in ai_models:
raise ModelDoesNotBelongToType(model_name=ai_field.ai_generative_ai_model)
for i, row in enumerate(rows):
context = HumanReadableRowContext(row, exclude_field_ids=[ai_field.id])
message = str(
resolve_formula(
ai_field.ai_prompt, formula_runtime_function_registry, context
)
)
try:
value = generative_ai_model_type.prompt(
ai_field.ai_generative_ai_model, message
)
except Exception as exc:
# If the prompt fails once, we should not continue with the other rows.
rows_ai_values_generation_error.send(
self,
user=user,
rows=rows[i:],
field=ai_field,
table=table,
error_message=str(exc),
)
raise exc
RowHandler().update_row_by_id(
user,
table,
row.id,
{ai_field.db_column: value},
model=model,
values_already_prepared=True,
)
@baserow_trace(tracer)
def _run_periodic_field_update(field, field_type_instance, all_updated_fields):
add_baserow_trace_attrs(field_id=field.id)

View file

@ -0,0 +1,38 @@
# Generated by Django 4.1.13 on 2024-02-08 17:26
import django.db.models.deletion
from django.db import migrations, models
import baserow.core.formula.field
class Migration(migrations.Migration):
dependencies = [
("database", "0153_passwordfield"),
]
operations = [
migrations.CreateModel(
name="AIField",
fields=[
(
"field_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="database.field",
),
),
("ai_generative_ai_type", models.CharField(max_length=32, null=True)),
("ai_generative_ai_model", models.CharField(max_length=32, null=True)),
("ai_prompt", baserow.core.formula.field.FormulaField(default="")),
],
options={
"abstract": False,
},
bases=("database.field",),
),
]

View file

@ -0,0 +1,32 @@
from typing import List, Union
from baserow.contrib.builder.data_sources.builder_dispatch_context import (
BuilderDispatchContext,
)
from baserow.core.formula.registries import DataProviderType
class HumanReadableFieldsDataProviderType(DataProviderType):
"""
This data provider type is used to read the human readable values for the row
fields. This is used for example in the AI field to be able to reference other
fields in the same row to generate a different prompt for each row based on the
values of the other fields.
"""
type = "fields"
def get_data_chunk(
self, dispatch_context: BuilderDispatchContext, path: List[str]
) -> Union[int, str]:
"""
When a page parameter is read, returns the value previously saved from the
request object.
"""
if len(path) != 1:
return None
first_part = path[0]
return dispatch_context.human_readable_row_values.get(first_part, "")

View file

@ -28,7 +28,6 @@ from django.utils.encoding import force_str
from opentelemetry import metrics, trace
from baserow.contrib.database.api.rows.serializers import serialize_rows_for_response
from baserow.contrib.database.fields.dependencies.handler import FieldDependencyHandler
from baserow.contrib.database.fields.dependencies.update_collector import (
FieldUpdateCollector,
@ -276,7 +275,9 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
prepared_values_by_field[
field_name
] = field_type.prepare_value_for_db_in_bulk(
field["field"], batch_values, continue_on_error=generate_error_report
field["field"],
batch_values,
continue_on_error=generate_error_report,
)
# replace original values to keep ordering
@ -964,8 +965,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
updated_field_ids=updated_field_ids,
)
before_rows_values = serialize_rows_for_response(rows, model)
if not values_already_prepared:
prepared_values = self.prepare_values(model._field_objects, values)
else:
@ -1043,7 +1042,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
model=model,
before_return=before_return,
updated_field_ids=updated_field_ids,
before_rows_values=before_rows_values,
m2m_change_tracker=m2m_change_tracker,
)
@ -1636,8 +1634,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
values["id"] = row.id
original_row_values_by_id[row.id] = values
before_rows_values = serialize_rows_for_response(rows_to_update, model)
before_return = before_rows_update.send(
self,
rows=list(rows_to_update),
@ -1815,7 +1811,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
model=model,
before_return=before_return,
updated_field_ids=updated_field_ids,
before_rows_values=before_rows_values,
m2m_change_tracker=m2m_change_tracker,
)
@ -1829,6 +1824,19 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
fields_metadata_by_row_id,
)
def get_rows(
self, model: GeneratedTableModel, row_ids: List[int]
) -> List[GeneratedTableModel]:
"""
Returns a list of rows based on the provided row ids.
:param model: The model that should be used to get the rows.
:param row_ids: The list of row ids that should be fetched.
:return: The list of rows.
"""
return model.objects.filter(id__in=row_ids).enhance_by_fields()
def get_rows_for_update(
self, model: GeneratedTableModel, row_ids: List[int]
) -> RowsForUpdate:
@ -1839,10 +1847,7 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
"""
return cast(
RowsForUpdate,
model.objects.select_for_update(of=("self",))
.enhance_by_fields()
.filter(id__in=row_ids),
RowsForUpdate, self.get_rows(model, row_ids).select_for_update(of=("self",))
)
def move_row_by_id(
@ -1906,7 +1911,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
before_return = before_rows_update.send(
self, rows=[row], user=user, table=table, model=model, updated_field_ids=[]
)
before_rows_values = serialize_rows_for_response([row], model)
row.order = self.get_unique_orders_before_row(before_row, model)[0]
row.save()
@ -1954,7 +1958,6 @@ class RowHandler(metaclass=baserow_trace_methods(tracer)):
model=model,
before_return=before_return,
updated_field_ids=[],
before_rows_values=before_rows_values,
prepared_rows_values=None,
)

View file

@ -0,0 +1,25 @@
from baserow.contrib.database.data_providers.registries import (
database_data_provider_type_registry,
)
from baserow.core.formula.runtime_formula_context import RuntimeFormulaContext
class HumanReadableRowContext(RuntimeFormulaContext):
def __init__(self, row, exclude_field_ids=None):
if exclude_field_ids is None:
exclude_field_ids = []
model = row._meta.model
self.human_readable_row_values = {
f"field_{field['field'].id}": field["type"].get_human_readable_value(
getattr(row, field["name"]), field
)
for field in model._field_objects.values()
if field["field"].id not in exclude_field_ids
}
super().__init__()
@property
def data_provider_registry(self):
return database_data_provider_type_registry

View file

@ -8,6 +8,7 @@ before_rows_delete = Signal()
rows_created = Signal()
rows_updated = Signal()
rows_deleted = Signal()
rows_ai_values_generation_error = Signal()
row_orders_recalculated = Signal()

View file

@ -5,6 +5,7 @@ from baserow.contrib.database.api.rows.serializers import (
serialize_rows_for_response,
)
from baserow.contrib.database.webhooks.registries import WebhookEventType
from baserow.contrib.database.ws.rows.signals import serialize_rows_values
from .signals import rows_created, rows_deleted, rows_updated
@ -62,19 +63,11 @@ class RowsUpdatedEventType(RowsEventType):
signal = rows_updated
def get_payload(
self,
event_id,
webhook,
model,
table,
rows,
before_return,
before_rows_values,
**kwargs
self, event_id, webhook, model, table, rows, before_return, **kwargs
):
payload = super().get_payload(event_id, webhook, model, table, rows, **kwargs)
old_items = before_rows_values
old_items = dict(before_return)[serialize_rows_values]
if webhook.use_user_field_names:
old_items = remap_serialized_rows_to_user_field_names(old_items, model)

View file

@ -11,7 +11,10 @@ from baserow.contrib.database.rows import signals as row_signals
from baserow.contrib.database.table.models import GeneratedTableModel
from baserow.contrib.database.views.handler import PublicViewRows, ViewHandler
from baserow.contrib.database.views.registries import view_type_registry
from baserow.contrib.database.ws.rows.signals import RealtimeRowMessages
from baserow.contrib.database.ws.rows.signals import (
RealtimeRowMessages,
serialize_rows_values,
)
from baserow.core.telemetry.utils import baserow_trace
from baserow.ws.registries import page_registry
@ -148,18 +151,10 @@ def public_before_rows_update(
@receiver(row_signals.rows_updated)
@baserow_trace(tracer)
def public_rows_updated(
sender,
rows,
user,
table,
model,
before_return,
updated_field_ids,
before_rows_values,
**kwargs
sender, rows, user, table, model, before_return, updated_field_ids, **kwargs
):
before_return_dict = dict(before_return)[public_before_rows_update]
serialized_old_rows = before_rows_values
serialized_old_rows = dict(before_return)[serialize_rows_values]
serialized_updated_rows = serialize_rows_for_response(rows, model)
old_row_public_views: List[PublicViewRows] = before_return_dict[

View file

@ -7,6 +7,7 @@ from baserow.contrib.database.api.rows.serializers import (
RowHistorySerializer,
RowSerializer,
get_row_serializer_class,
serialize_rows_for_response,
)
from baserow.contrib.database.rows import signals as row_signals
from baserow.contrib.database.rows.registries import row_metadata_registry
@ -14,6 +15,13 @@ from baserow.contrib.database.table.models import GeneratedTableModel
from baserow.ws.registries import page_registry
@receiver(row_signals.before_rows_update)
def serialize_rows_values(
sender, rows, user, table, model, updated_field_ids, **kwargs
):
return serialize_rows_for_response(rows, model)
@receiver(row_signals.rows_created)
def rows_created(
sender,
@ -57,10 +65,10 @@ def rows_updated(
model,
before_return,
updated_field_ids,
before_rows_values,
**kwargs,
):
table_page_type = page_registry.get("table")
before_rows_values = dict(before_return)[serialize_rows_values]
transaction.on_commit(
lambda: table_page_type.broadcast(
RealtimeRowMessages.rows_updated(
@ -79,6 +87,26 @@ def rows_updated(
)
@receiver(row_signals.rows_ai_values_generation_error)
def rows_ai_values_generation_error(
sender, user, rows, field, table, error_message, **kwargs
):
table_page_type = page_registry.get("table")
transaction.on_commit(
lambda: table_page_type.broadcast(
{
"type": "rows_ai_values_generation_error",
"field_id": field.id,
"table_id": table.id,
"row_ids": [row.id for row in rows],
"error": error_message,
},
getattr(user, "web_socket_id", None),
table_id=table.id,
)
)
@receiver(row_signals.before_rows_delete)
def before_rows_delete(sender, rows, user, table, model, **kwargs):
return get_row_serializer_class(model, RowSerializer, is_response=True)(

View file

@ -296,6 +296,17 @@ class CoreConfig(AppConfig):
)
notification_type_registry.register(BaserowVersionUpgradeNotificationType())
from baserow.core.generative_ai.generative_ai_models import (
OllamaGenerativeAIModelType,
OpenAIGenerativeAIModelType,
)
from baserow.core.generative_ai.registries import (
generative_ai_model_type_registry,
)
generative_ai_model_type_registry.register(OpenAIGenerativeAIModelType())
generative_ai_model_type_registry.register(OllamaGenerativeAIModelType())
# Must import the Posthog signal, otherwise it won't work.
import baserow.core.posthog # noqa: F403, F401

View file

@ -0,0 +1,17 @@
from baserow.core.exceptions import InstanceTypeDoesNotExist
class GenerativeAITypeDoesNotExist(InstanceTypeDoesNotExist):
"""Raised when trying to get a generative AI type that does not exist."""
class ModelDoesNotBelongToType(Exception):
"""Raised when trying to get a model that does not belong to the type."""
def __init__(self, model_name, *args, **kwargs):
self.model_name = model_name
super().__init__(*args, **kwargs)
class GenerativeAIPromptError(Exception):
"""Raised when an error occurs while prompting the model."""

View file

@ -0,0 +1,60 @@
from django.conf import settings
from ollama import Client as OllamaClient
from ollama import RequestError as OllamaRequestError
from ollama import ResponseError as OllamaResponseError
from openai import APIStatusError as OpenAIAPIStatusError
from openai import OpenAI, OpenAIError
from baserow.core.generative_ai.exceptions import GenerativeAIPromptError
from .registries import GenerativeAIModelType
class OpenAIGenerativeAIModelType(GenerativeAIModelType):
type = "openai"
def is_enabled(self):
return bool(settings.BASEROW_OPENAI_API_KEY) and bool(self.get_enabled_models())
def get_enabled_models(self):
return settings.BASEROW_OPENAI_MODELS
def prompt(self, model, prompt):
client = OpenAI(
api_key=settings.BASEROW_OPENAI_API_KEY,
organization=settings.BASEROW_OPENAI_ORGANIZATION,
)
try:
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": prompt,
}
],
model=model,
stream=False,
)
except (OpenAIError, OpenAIAPIStatusError) as exc:
raise GenerativeAIPromptError(str(exc)) from exc
return chat_completion.choices[0].message.content
class OllamaGenerativeAIModelType(GenerativeAIModelType):
type = "ollama"
def is_enabled(self):
return bool(settings.BASEROW_OLLAMA_HOST) and bool(self.get_enabled_models())
def get_enabled_models(self):
return settings.BASEROW_OLLAMA_MODELS
def prompt(self, model, prompt):
client = OllamaClient(host=settings.BASEROW_OLLAMA_HOST)
try:
response = client.generate(model=model, prompt=prompt, stream=False)
except (OllamaRequestError, OllamaResponseError) as exc:
raise GenerativeAIPromptError(str(exc)) from exc
return response["response"]

View file

@ -0,0 +1,31 @@
from baserow.core.registry import Instance, Registry
from .exceptions import GenerativeAITypeDoesNotExist
class GenerativeAIModelType(Instance):
def is_enabled(self):
return False
def get_enabled_models(self):
return []
def prompt(self, model, prompt):
raise NotImplementedError("The prompt function must be implemented.")
class GenerativeAIModelTypeRegistry(Registry):
name = "generative_ai_model_type"
does_not_exist_exception_class = GenerativeAITypeDoesNotExist
def get_models_per_type(self):
return {
key: value.get_enabled_models()
for key, value in self.registry.items()
if value.is_enabled()
}
generative_ai_model_type_registry: GenerativeAIModelTypeRegistry = (
GenerativeAIModelTypeRegistry()
)

View file

@ -9,6 +9,7 @@ from .domain import DomainFixtures
from .element import ElementFixtures
from .field import FieldFixtures
from .file_import import FileImportFixtures
from .generative_ai import GenerativeAIFixtures
from .integration import IntegrationFixtures
from .job import JobFixtures
from .notifications import NotificationsFixture
@ -59,6 +60,7 @@ class Fixtures(
UserSourceFixtures,
AppAuthProviderFixtures,
UserSourceUserFixtures,
GenerativeAIFixtures,
):
def __init__(self, fake=None):
self.fake = fake

View file

@ -5,6 +5,7 @@ from baserow.contrib.database.fields.dependencies.handler import FieldDependency
from baserow.contrib.database.fields.field_cache import FieldCache
from baserow.contrib.database.fields.field_types import AutonumberFieldType
from baserow.contrib.database.fields.models import (
AIField,
AutonumberField,
BooleanField,
CreatedByField,
@ -405,6 +406,29 @@ class FieldFixtures:
self.set_test_field_kwarg_defaults(user, kwargs)
field = PasswordField.objects.create(**kwargs)
if create_field:
self.create_model_field(kwargs["table"], field)
return field
def create_ai_field(self, user=None, create_field=True, **kwargs):
self.set_test_field_kwarg_defaults(user, kwargs)
# Register the fake generative AI model for testing purposes.
self.register_fake_generate_ai_type()
if "ai_generative_ai_type" not in kwargs:
kwargs["ai_generative_ai_type"] = "test_generative_ai"
if "ai_generative_ai_model" not in kwargs:
kwargs["ai_generative_ai_model"] = "test_1"
if "ai_prompt" not in kwargs:
kwargs[
"ai_prompt"
] = "'What is your purpose? Answer with a maximum of 10 words.'"
field = AIField.objects.create(**kwargs)
if create_field:
self.create_model_field(kwargs["table"], field)

View file

@ -0,0 +1,65 @@
from baserow.core.generative_ai.exceptions import GenerativeAIPromptError
class GenerativeAIFixtures:
def register_fake_generate_ai_type(self, **kwargs):
from baserow.core.generative_ai.registries import (
GenerativeAIModelType,
generative_ai_model_type_registry,
)
class TestGenerativeAINoModelType(GenerativeAIModelType):
type = "test_generative_ai_no_model"
def is_enabled(self):
return True
def get_enabled_models(self):
return []
def prompt(self, model, prompt):
return ""
class TestGenerativeAIModelType(GenerativeAIModelType):
type = "test_generative_ai"
def is_enabled(self):
return True
def get_enabled_models(self):
return ["test_1"]
def prompt(self, model, prompt):
return f"Generated: {prompt}"
class TestGenerativeAIModelTypePromptError(GenerativeAIModelType):
type = "test_generative_ai_prompt_error"
def is_enabled(self):
return True
def get_enabled_models(self):
return ["test_1"]
def prompt(self, model, prompt):
raise GenerativeAIPromptError("Test error")
if (
TestGenerativeAINoModelType.type
not in generative_ai_model_type_registry.registry
):
generative_ai_model_type_registry.register(TestGenerativeAINoModelType())
if (
TestGenerativeAIModelType.type
not in generative_ai_model_type_registry.registry
):
generative_ai_model_type_registry.register(TestGenerativeAIModelType())
if (
TestGenerativeAIModelTypePromptError.type
not in generative_ai_model_type_registry.registry
):
generative_ai_model_type_registry.register(
TestGenerativeAIModelTypePromptError()
)

View file

@ -96,6 +96,8 @@ def setup_interesting_test_table(
if database is None:
database = data_fixture.create_database_application(user=user)
data_fixture.register_fake_generate_ai_type()
file_suffix = file_suffix or ""
try:
@ -241,6 +243,7 @@ def setup_interesting_test_table(
],
"phone_number": "+4412345678",
"password": "test",
"ai": "I'm an AI.",
}
with freeze_time("2020-02-01 01:23"):

View file

@ -0,0 +1,268 @@
from unittest.mock import patch
from django.conf import settings
from django.shortcuts import reverse
import pytest
from rest_framework.status import (
HTTP_202_ACCEPTED,
HTTP_400_BAD_REQUEST,
HTTP_404_NOT_FOUND,
)
from baserow.contrib.database.rows.handler import RowHandler
@pytest.mark.django_db
@pytest.mark.field_ai
def test_generate_ai_field_value_view_field_does_not_exist(data_fixture, api_client):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(table=table, name="ai")
rows = RowHandler().create_rows(
user,
table,
rows_values=[{}],
)
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": 0},
),
{"row_ids": [rows[0].id]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_404_NOT_FOUND
assert response.json()["error"] == "ERROR_FIELD_DOES_NOT_EXIST"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_generate_ai_field_value_view_row_does_not_exist(data_fixture, api_client):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(table=table, name="ai")
rows = RowHandler().create_rows(
user,
table,
rows_values=[{}],
)
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": [0]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_404_NOT_FOUND
assert response.json()["error"] == "ERROR_ROW_DOES_NOT_EXIST"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_generate_ai_field_value_view_user_not_in_workspace(data_fixture, api_client):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token(
email="test@test.nl", password="password", first_name="Test1"
)
user_2, token_2 = data_fixture.create_user_and_token(
email="test2@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(table=table, name="ai")
rows = RowHandler().create_rows(
user,
table,
rows_values=[{}],
)
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": [rows[0].id]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token_2}",
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert response.json()["error"] == "ERROR_USER_NOT_IN_GROUP"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_generate_ai_field_value_view_generative_ai_does_not_exist(
data_fixture, api_client
):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(
table=table, name="ai", ai_generative_ai_type="does_not_exist"
)
rows = RowHandler().create_rows(
user,
table,
rows_values=[{}],
)
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": [rows[0].id]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert response.json()["error"] == "ERROR_GENERATIVE_AI_DOES_NOT_EXIST"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_generate_ai_field_value_view_generative_ai_model_does_not_belong_to_type(
data_fixture, api_client
):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(
table=table, name="ai", ai_generative_ai_model="does_not_exist"
)
rows = RowHandler().create_rows(
user,
table,
rows_values=[
{},
],
)
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": [rows[0].id]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert response.json()["error"] == "ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE"
@pytest.mark.django_db
@pytest.mark.field_ai
@patch("baserow.contrib.database.fields.tasks.generate_ai_values_for_rows.apply")
def test_generate_ai_field_value_view_generative_ai(
patched_generate_ai_values_for_rows, data_fixture, api_client
):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt="'Hello'")
rows = RowHandler().create_rows(
user,
table,
rows_values=[{}],
)
assert patched_generate_ai_values_for_rows.call_count == 0
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": [rows[0].id]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_202_ACCEPTED
assert patched_generate_ai_values_for_rows.call_count == 1
@pytest.mark.django_db
@pytest.mark.field_ai
def test_batch_generate_ai_field_value_limit(api_client, data_fixture):
data_fixture.register_fake_generate_ai_type()
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt="'Hello'")
rows = RowHandler().create_rows(
user,
table,
rows_values=[{}] * (settings.BATCH_ROWS_SIZE_LIMIT + 1),
)
row_ids = [row.id for row in rows]
# BATCH_ROWS_SIZE_LIMIT rows are allowed
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": row_ids[: settings.BATCH_ROWS_SIZE_LIMIT]},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_202_ACCEPTED
# BATCH_ROWS_SIZE_LIMIT + 1 rows are not allowed
response = api_client.post(
reverse(
"api:database:fields:async_generate_ai_field_values",
kwargs={"field_id": field.id},
),
{"row_ids": row_ids},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert response.json()["error"] == "ERROR_REQUEST_BODY_VALIDATION"
assert response.json()["detail"] == {
"row_ids": [
{
"code": "max_length",
"error": f"Ensure this field has no more than"
f" {settings.BATCH_ROWS_SIZE_LIMIT} elements.",
},
],
}

View file

@ -375,6 +375,7 @@ def test_get_row_serializer_with_user_field_names(data_fixture):
"uuid": "00000000-0000-4000-8000-000000000003",
"autonumber": 2,
"password": True,
"ai": "I'm an AI.",
}
)
)

View file

@ -233,12 +233,12 @@ def test_can_export_every_interesting_different_field_to_csv(
"phone_number,formula_text,formula_int,formula_bool,formula_decimal,formula_dateinterval,"
"formula_date,formula_singleselect,formula_email,formula_link_with_label,"
"formula_link_url_only,formula_multipleselect,count,rollup,lookup,uuid,"
"autonumber,password\r\n"
"autonumber,password,ai\r\n"
"1,,,,,,,,,0,False,,,,,,,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
"02/01/2021 13:00,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,02/01/2021 13:00,"
"user@example.com,user@example.com,,,,,,,,,,,,,,,,,,,test FORMULA,1,True,33.3333333333,"
"1d 0:00,2020-01-01,,,label (https://google.com),https://google.com,,0,0.000,,"
"00000000-0000-4000-8000-000000000002,1,\r\n"
"00000000-0000-4000-8000-000000000002,1,,\r\n"
"2,text,long_text,https://www.google.com,test@example.com,-1,1,-1.2,1.2,3,True,"
"02/01/2020 01:23,02/01/2020,01/02/2020 01:23,01/02/2020,01/02/2020 02:23,"
"01/02/2020 02:23,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
@ -253,7 +253,7 @@ def test_can_export_every_interesting_different_field_to_csv(
'"user2@example.com,user3@example.com",\'+4412345678,test FORMULA,1,True,33.3333333333,'
"1d 0:00,2020-01-01,A,test@example.com,label (https://google.com),https://google.com,"
'"D,C,E",3,-122.222,"linked_row_1,linked_row_2,",00000000-0000-4000-8000-000000000003,'
"2,True\r\n"
"2,True,I'm an AI.\r\n"
)
assert contents == expected

View file

@ -0,0 +1,139 @@
from unittest.mock import patch
import pytest
from baserow.contrib.database.fields.tasks import generate_ai_values_for_rows
from baserow.contrib.database.rows.handler import RowHandler
from baserow.core.generative_ai.exceptions import GenerativeAIPromptError
@pytest.mark.django_db
@pytest.mark.field_ai
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
def test_generate_ai_field_value_view_generative_ai(patched_rows_updated, data_fixture):
data_fixture.register_fake_generate_ai_type()
user = data_fixture.create_user(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt="'Hello'")
rows = RowHandler().create_rows(user, table, rows_values=[{}])
assert patched_rows_updated.call_count == 0
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
assert patched_rows_updated.call_count == 1
updated_row = patched_rows_updated.call_args[1]["rows"][0]
assert getattr(updated_row, field.db_column) == "Generated: Hello"
assert patched_rows_updated.call_args[1]["updated_field_ids"] == set([field.id])
@pytest.mark.django_db
@pytest.mark.field_ai
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
def test_generate_ai_field_value_view_generative_ai_parse_formula(
patched_rows_updated, data_fixture
):
data_fixture.register_fake_generate_ai_type()
user = data_fixture.create_user(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
firstname = data_fixture.create_text_field(table=table, name="firstname")
lastname = data_fixture.create_text_field(table=table, name="lastname")
formula = f"concat('Hello ', get('fields.field_{firstname.id}'), ' ', get('fields.field_{lastname.id}'))"
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt=formula)
rows = RowHandler().create_rows(
user,
table,
rows_values=[
{f"field_{firstname.id}": "Bram", f"field_{lastname.id}": "Wiepjes"},
],
)
assert patched_rows_updated.call_count == 0
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
assert patched_rows_updated.call_count == 1
updated_row = patched_rows_updated.call_args[1]["rows"][0]
assert getattr(updated_row, field.db_column) == "Generated: Hello Bram Wiepjes"
assert patched_rows_updated.call_args[1]["updated_field_ids"] == set([field.id])
@pytest.mark.django_db
@pytest.mark.field_ai
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
def test_generate_ai_field_value_view_generative_ai_invalid_field(
patched_rows_updated, data_fixture
):
data_fixture.register_fake_generate_ai_type()
user = data_fixture.create_user(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
firstname = data_fixture.create_text_field(table=table, name="firstname")
formula = "concat('Hello ', get('fields.field_0'))"
field = data_fixture.create_ai_field(table=table, name="ai", ai_prompt=formula)
rows = RowHandler().create_rows(
user,
table,
rows_values=[{f"field_{firstname.id}": "Bram"}],
)
assert patched_rows_updated.call_count == 0
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
assert patched_rows_updated.call_count == 1
updated_row = patched_rows_updated.call_args[1]["rows"][0]
assert getattr(updated_row, field.db_column) == "Generated: Hello "
@pytest.mark.django_db
@pytest.mark.field_ai
@patch("baserow.contrib.database.rows.signals.rows_ai_values_generation_error.send")
@patch("baserow.contrib.database.rows.signals.rows_updated.send")
def test_generate_ai_field_value_view_generative_ai_invalid_prompt(
patched_rows_updated, patched_rows_ai_values_generation_error, data_fixture
):
data_fixture.register_fake_generate_ai_type()
user = data_fixture.create_user(
email="test@test.nl", password="password", first_name="Test1"
)
database = data_fixture.create_database_application(user=user, name="database")
table = data_fixture.create_database_table(name="table", database=database)
firstname = data_fixture.create_text_field(table=table, name="firstname")
formula = "concat('Hello ', get('fields.field_0'))"
field = data_fixture.create_ai_field(
table=table,
name="ai",
ai_generative_ai_type="test_generative_ai_prompt_error",
ai_prompt=formula,
)
rows = RowHandler().create_rows(
user,
table,
rows_values=[{f"field_{firstname.id}": "Bram"}],
)
assert patched_rows_ai_values_generation_error.call_count == 0
with pytest.raises(GenerativeAIPromptError):
generate_ai_values_for_rows(user.id, field.id, [rows[0].id])
assert patched_rows_updated.call_count == 0
assert patched_rows_ai_values_generation_error.call_count == 1
call_args_rows = patched_rows_ai_values_generation_error.call_args[1]["rows"]
assert len(call_args_rows) == 1
assert rows[0].id == call_args_rows[0].id
assert patched_rows_ai_values_generation_error.call_args[1]["field"] == field
assert (
patched_rows_ai_values_generation_error.call_args[1]["error_message"]
== "Test error"
)

View file

@ -0,0 +1,240 @@
from django.shortcuts import reverse
import pytest
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
from baserow.contrib.database.fields.handler import FieldHandler
from baserow.contrib.database.fields.models import AIField
@pytest.mark.django_db
@pytest.mark.field_ai
def test_create_ai_field_type(data_fixture):
user = data_fixture.create_user()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
data_fixture.create_text_field(table=table, order=1, name="name")
handler = FieldHandler()
ai_field = handler.create_field(
user=user,
table=table,
type_name="ai",
name="ai_1",
ai_generative_ai_type="test_generative_ai",
ai_generative_ai_model="test_1",
ai_prompt="'Who are you?'",
)
assert ai_field.ai_generative_ai_type == "test_generative_ai"
assert ai_field.ai_generative_ai_model == "test_1"
assert ai_field.ai_prompt == "'Who are you?'"
assert len(AIField.objects.all()) == 1
@pytest.mark.django_db
@pytest.mark.field_ai
def test_update_ai_field_type(data_fixture):
user = data_fixture.create_user()
table = data_fixture.create_database_table(user=user)
field = data_fixture.create_ai_field(table=table, order=1, name="name")
handler = FieldHandler()
ai_field = handler.update_field(
user=user,
field=field,
name="ai_1",
ai_generative_ai_type="test_generative_ai",
ai_generative_ai_model="test_1",
ai_prompt="'Who are you?'",
)
assert ai_field.ai_generative_ai_type == "test_generative_ai"
assert ai_field.ai_generative_ai_model == "test_1"
assert ai_field.ai_prompt == "'Who are you?'"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_delete_ai_field_type(data_fixture):
user = data_fixture.create_user()
table = data_fixture.create_database_table(user=user)
field = data_fixture.create_ai_field(
table=table,
order=1,
name="name",
ai_generative_ai_type="test_generative_ai",
ai_generative_ai_model="test_1",
ai_prompt="'Who are you?'",
)
handler = FieldHandler()
handler.delete_field(user=user, field=field)
assert len(AIField.objects.all()) == 0
@pytest.mark.django_db
@pytest.mark.field_ai
def test_create_ai_field_type_via_api(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
data_fixture.create_text_field(table=table, order=1, name="name")
response = api_client.post(
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
{
"name": "Test 1",
"type": "ai",
"ai_generative_ai_type": "test_generative_ai",
"ai_generative_ai_model": "test_1",
"ai_prompt": "'Who are you?'",
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_200_OK
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
assert response_json["ai_generative_ai_model"] == "test_1"
assert response_json["ai_prompt"] == "'Who are you?'"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_create_ai_field_type_via_api_invalid_formula(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
data_fixture.create_text_field(table=table, order=1, name="name")
response = api_client.post(
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
{
"name": "Test 1",
"type": "ai",
"ai_generative_ai_type": "test_generative_ai",
"ai_generative_ai_model": "test_1",
"ai_prompt": "ffff;;s9(",
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_400_BAD_REQUEST
response_json = response.json()
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
assert response_json["detail"]["ai_prompt"][0]["code"] == "invalid"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_create_ai_field_type_via_api_with_invalid_type(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
data_fixture.create_text_field(table=table, order=1, name="name")
response = api_client.post(
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
{
"name": "Test 1",
"type": "ai",
"ai_generative_ai_type": "does_not_exist",
"ai_generative_ai_model": "test_1",
"ai_prompt": "'Who are you?'",
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_400_BAD_REQUEST
assert response_json["error"] == "ERROR_GENERATIVE_AI_DOES_NOT_EXIST"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_create_ai_field_type_via_api_with_invalid_model(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
data_fixture.create_text_field(table=table, order=1, name="name")
response = api_client.post(
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
{
"name": "Test 1",
"type": "ai",
"ai_generative_ai_type": "test_generative_ai",
"ai_generative_ai_model": "does_not_exist",
"ai_prompt": "'Who are you?'",
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_400_BAD_REQUEST
assert response_json["error"] == "ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_update_ai_field_type_via_api_with_invalid_type(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
field = data_fixture.create_ai_field(table=table, order=1, name="name")
response = api_client.patch(
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
{"ai_generative_ai_type": "does_not_exist"},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_400_BAD_REQUEST
assert response_json["error"] == "ERROR_GENERATIVE_AI_DOES_NOT_EXIST"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_update_ai_field_type_via_api_with_invalid_model(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
field = data_fixture.create_ai_field(table=table, order=1, name="name")
response = api_client.patch(
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
{"ai_generative_ai_model": "does_not_exist"},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_400_BAD_REQUEST
assert response_json["error"] == "ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE"
@pytest.mark.django_db
@pytest.mark.field_ai
def test_update_ai_field_type_via_api_with_valid_model(data_fixture, api_client):
user, token = data_fixture.create_user_and_token()
table = data_fixture.create_database_table(user=user)
data_fixture.register_fake_generate_ai_type()
field = data_fixture.create_ai_field(table=table, order=1, name="name")
response = api_client.patch(
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
{"ai_generative_ai_model": "test_1"},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_200_OK
response = api_client.patch(
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
{"ai_generative_ai_type": "test_generative_ai"},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_200_OK

View file

@ -26,6 +26,7 @@ from baserow.contrib.database.fields.field_helpers import (
construct_all_possible_field_kwargs,
)
from baserow.contrib.database.fields.field_types import (
AIFieldType,
AutonumberFieldType,
BooleanFieldType,
CountFieldType,
@ -363,6 +364,13 @@ def test_field_conversion_autonumber(data_fixture):
_test_can_convert_between_fields(data_fixture, AutonumberFieldType.type)
@pytest.mark.field_ai
@pytest.mark.disabled_in_ci
@pytest.mark.django_db
def test_field_conversion_ai(data_fixture):
_test_can_convert_between_fields(data_fixture, AIFieldType.type)
@pytest.mark.django_db
def test_get_field(data_fixture):
user = data_fixture.create_user()

View file

@ -113,6 +113,7 @@ def test_fill_table_fields_with_add_all_fields(data_fixture):
user = data_fixture.create_user()
table = data_fixture.create_database_table(user=user)
text_field = data_fixture.create_text_field(user=user, table=table)
data_fixture.register_fake_generate_ai_type()
model = table.get_model()

View file

@ -1,9 +1,9 @@
import pytest
from baserow.contrib.database.api.rows.serializers import serialize_rows_for_response
from baserow.contrib.database.fields.handler import FieldHandler
from baserow.contrib.database.rows.handler import RowHandler
from baserow.contrib.database.webhooks.registries import webhook_event_type_registry
from baserow.contrib.database.ws.rows.signals import serialize_rows_values
@pytest.mark.django_db()
@ -92,8 +92,11 @@ def test_rows_updated_event_type(data_fixture):
row = model.objects.create(**{f"field_{text_field.id}": "Old Test value"})
getattr(row, f"field_{link_row_field.id}").add(i1.id)
before_return = {}
before_rows_values = serialize_rows_for_response([row], model)
before_return = {
serialize_rows_values: serialize_rows_values(
None, [row], user, table, model, [text_field.id]
)
}
row = RowHandler().update_row_by_id(
user=user,
@ -116,7 +119,6 @@ def test_rows_updated_event_type(data_fixture):
table=table,
rows=[row],
before_return=before_return,
before_rows_values=before_rows_values,
)
assert payload == {
"table_id": table.id,
@ -151,7 +153,6 @@ def test_rows_updated_event_type(data_fixture):
table=table,
rows=[row],
before_return=before_return,
before_rows_values=before_rows_values,
)
assert payload == {
"table_id": table.id,

View file

@ -1724,6 +1724,13 @@ def test_local_baserow_table_service_generate_schema_with_interesting_test_table
"metadata": {},
"type": "boolean",
},
field_db_column_by_name["ai"]: {
"title": "ai",
"default": None,
"original_type": "ai",
"metadata": {},
"type": "string",
},
"id": {"metadata": {}, "type": "number", "title": "Id"},
}

View file

@ -0,0 +1,7 @@
{
"type": "feature",
"message": "Introduced a new AI field type",
"issue_number": null,
"bullet_points": [],
"created_at": "2024-04-02"
}

View file

@ -164,6 +164,11 @@ x-backend-variables: &backend-variables
BASEROW_BUILDER_DOMAINS:
SENTRY_DSN:
SENTRY_BACKEND_DSN:
BASEROW_OPENAI_API_KEY:
BASEROW_OPENAI_ORGANIZATION:
BASEROW_OPENAI_MODELS:
BASEROW_OLLAMA_HOST:
BASEROW_OLLAMA_MODELS:
services:

View file

@ -105,7 +105,8 @@ def test_can_export_every_interesting_different_field_to_json(
"lookup": [],
"uuid": "00000000-0000-4000-8000-000000000002",
"autonumber": 1,
"password": ""
"password": "",
"ai": ""
},
{
"id": 2,
@ -221,7 +222,8 @@ def test_can_export_every_interesting_different_field_to_json(
],
"uuid": "00000000-0000-4000-8000-000000000003",
"autonumber": 2,
"password": true
"password": true,
"ai": "I'm an AI."
}
]
"""
@ -371,6 +373,7 @@ def test_can_export_every_interesting_different_field_to_xml(
<uuid>00000000-0000-4000-8000-000000000002</uuid>
<autonumber>1</autonumber>
<password/>
<ai/>
</row>
<row>
<id>2</id>
@ -487,6 +490,7 @@ def test_can_export_every_interesting_different_field_to_xml(
<uuid>00000000-0000-4000-8000-000000000003</uuid>
<autonumber>2</autonumber>
<password>true</password>
<ai>I'm an AI.</ai>
</row>
</rows>
"""

View file

@ -20,6 +20,7 @@
:visible-fields="visibleCardFields"
:hidden-fields="hiddenFields"
:show-hidden-fields="showHiddenFieldsInRowModal"
:all-fields-in-table="fields"
@toggle-hidden-fields-visibility="
showHiddenFieldsInRowModal = !showHiddenFieldsInRowModal
"
@ -35,7 +36,7 @@
:database="database"
:table="table"
:view="view"
:fields="fields"
:all-fields-in-table="fields"
:primary-is-sortable="false"
:visible-fields="visibleCardFields"
:hidden-fields="hiddenFields"

View file

@ -80,6 +80,7 @@
:visible-fields="cardFields"
:hidden-fields="hiddenFields"
:show-hidden-fields="showHiddenFieldsInRowModal"
:all-fields-in-table="fields"
@toggle-hidden-fields-visibility="
showHiddenFieldsInRowModal = !showHiddenFieldsInRowModal
"
@ -95,7 +96,7 @@
:database="database"
:table="table"
:view="view"
:fields="fields"
:all-fields-in-table="fields"
:primary-is-sortable="true"
:visible-fields="cardFields"
:hidden-fields="hiddenFields"

View file

@ -120,7 +120,8 @@
"singleSelectDropdown": "Dropdown",
"singleSelectRadios": "Radios",
"autonumber": "Autonumber",
"password": "Password"
"password": "Password",
"ai": "AI prompt"
},
"fieldErrors": {
"invalidNumber": "Invalid number",
@ -311,13 +312,16 @@
"disabledPasswordProviderMessage": "Please use another authentication provider.",
"maxLocksPerTransactionExceededTitle": "PostgreSQL issue detected",
"maxLocksPerTransactionExceededDescription": "Baserow attempted to permanently delete the trashed items, but exceeded the available locks specified in `max_locks_per_transaction`.",
"disabledPasswordProviderMessage": "Please use another authentication provider.",
"lastAdminTitle": "Can't remove last workspace admin",
"lastAdminMessage": "A workspace has to have at least one admin.",
"adminAlreadyExistsTitle": "Can't use that username",
"adminAlreadyExistsDescription": "That username can't be used because it already exists.",
"cannotCreateFieldTypeTitle": "Can't create field",
"cannotCreateFieldTypeDescription": "The requested field type cannot be created at the moment. It might be just a temporary problem with older tables. Please try again later."
"cannotCreateFieldTypeDescription": "The requested field type cannot be created at the moment. It might be just a temporary problem with older tables. Please try again later.",
"generativeAIDoesNotExistTitle": "Generative AI does not exist",
"generativeAIDoesNotExistDescription": "The generative AI model does not exist.",
"modelDoesNotBelongToTypeTitle": "The selected model does not belong to the AI Type",
"modelDoesNotBelongToTypeDescription": "The selected model does not belong to the selected AI type."
},
"importerType": {
"csv": "Import a CSV file",

View file

@ -22,6 +22,7 @@
:nodes="nodes"
:node-selected="nodeSelected"
:loading="dataExplorerLoading"
:application-context="applicationContext"
@node-selected="dataExplorerItemSelected"
@node-toggled="editor.commands.focus()"
@focusin="dataExplorerFocused = true"
@ -52,7 +53,10 @@ export default {
},
provide() {
// Provide the application context to all formula components
return { applicationContext: this.applicationContext }
return {
applicationContext: this.applicationContext,
dataProviders: this.dataProviders,
}
},
props: {
value: {

View file

@ -40,11 +40,16 @@ export default {
NodeViewWrapper,
},
mixins: [formulaComponent],
inject: ['applicationContext'],
inject: ['applicationContext', 'dataProviders'],
data() {
return { nodes: [], pathParts: [] }
},
computed: {
availableData() {
return Object.values(this.dataProviders).map((dataProvider) =>
dataProvider.getNodes(this.applicationContext)
)
},
isInvalid() {
return this.findNode(this.nodes, _.toPath(this.path)) === null
},
@ -58,7 +63,10 @@ export default {
return _.toPath(this.path)
},
dataProviderType() {
return this.$registry.get('builderDataProvider', this.rawPathParts[0])
const pathParts = this.rawPathParts
return this.dataProviders.find(
(dataProvider) => dataProvider.type === pathParts[0]
)
},
},
mounted() {

View file

@ -147,6 +147,14 @@ export class ClientErrorMap {
app.i18n.t('clientHandler.lastAdminTitle'),
app.i18n.t('clientHandler.lastAdminMessage')
),
ERROR_GENERATIVE_AI_DOES_NOT_EXIST: new ResponseErrorMessage(
app.i18n.t('clientHandler.generativeAIDoesNotExistTitle'),
app.i18n.t('clientHandler.generativeAIDoesNotExistDescription')
),
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE: new ResponseErrorMessage(
app.i18n.t('clientHandler.modelDoesNotBelongToTypeTitle'),
app.i18n.t('clientHandler.modelDoesNotBelongToTypeDescription')
),
}
}

View file

@ -30,6 +30,7 @@
:table="table"
:view="view"
:forced-type="singleSelectFieldType"
:all-fields-in-table="fields"
@field-created="$event.callback()"
></CreateFieldContext>
</div>

View file

@ -11,6 +11,7 @@
:table="table"
:view="view"
:forced-type="forcedType"
:all-fields-in-table="allFieldsInTable"
@submitted="submit"
@keydown-enter="$refs.submitButton.focus()"
>
@ -57,6 +58,10 @@ export default {
required: false,
default: false,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {

View file

@ -32,7 +32,6 @@
</template>
<script>
import { mapGetters } from 'vuex'
import { notifyIf } from '@baserow/modules/core/utils/error'
import { createNewUndoRedoActionGroupId } from '@baserow/modules/database/utils/action'
import FieldService from '@baserow/modules/database/services/field'
@ -52,6 +51,10 @@ export default {
type: Object,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {
@ -62,14 +65,11 @@ export default {
},
computed: {
existingFieldName() {
return this.fields.map((field) => field.name)
return this.allFieldsInTable.map((field) => field.name)
},
formFieldTypeIsReadOnly() {
return this.$registry.get('field', this.fromField.type).isReadOnly
},
...mapGetters({
fields: 'field/getAll',
}),
},
methods: {
onDuplicationEnd() {

View file

@ -0,0 +1,141 @@
<template>
<div>
<div class="control">
<label class="control__label control__label--small">{{
$t('fieldAISubForm.AIType')
}}</label>
<div class="control__elements">
<Dropdown
v-model="values.ai_generative_ai_type"
class="dropdown--floating"
:class="{
'dropdown--error': $v.values.ai_generative_ai_type.$error,
}"
:fixed-items="true"
:show-search="false"
small
@hide="$v.values.ai_generative_ai_type.$touch()"
@change="$refs.aiModel.select(aIModelsPerType[0])"
>
<DropdownItem
v-for="aIType in aITypes"
:key="aIType"
:name="aIType"
:value="aIType"
/>
</Dropdown>
</div>
</div>
<div class="control">
<label class="control__label control__label--small">
{{ $t('fieldAISubForm.AIModel') }}
</label>
<div class="control__elements">
<Dropdown
ref="aiModel"
v-model="values.ai_generative_ai_model"
class="dropdown--floating"
:class="{
'dropdown--error': $v.values.ai_generative_ai_model.$error,
}"
:fixed-items="true"
:show-search="false"
small
@hide="$v.values.ai_generative_ai_model.$touch()"
>
<DropdownItem
v-for="aIType in aIModelsPerType"
:key="aIType"
:name="aIType"
:value="aIType"
/>
</Dropdown>
</div>
</div>
<div class="control">
<label class="control__label control__label--small">
{{ $t('fieldAISubForm.prompt') }}
</label>
<div class="control__elements">
<div style="max-width: 366px">
<FormulaInputField
v-model="values.ai_prompt"
:data-providers="dataProviders"
:application-context="applicationContext"
placeholder="What is Baserow?"
></FormulaInputField>
</div>
<div
v-if="$v.values.ai_prompt.$dirty && $v.values.ai_prompt.$error"
class="error"
>
{{ $t('error.requiredField') }}
</div>
</div>
</div>
</div>
</template>
<script>
import { mapGetters } from 'vuex'
import { required } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form'
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
import FormulaInputField from '@baserow/modules/core/components/formula/FormulaInputField'
export default {
name: 'FieldAISubForm',
components: { FormulaInputField },
mixins: [form, fieldSubForm],
data() {
return {
allowedValues: [
'ai_generative_ai_type',
'ai_generative_ai_model',
'ai_prompt',
],
values: {
ai_generative_ai_type: null,
ai_generative_ai_model: null,
ai_prompt: '',
},
}
},
computed: {
...mapGetters({
settings: 'settings/get',
}),
applicationContext() {
const context = {}
Object.defineProperty(context, 'fields', {
enumerable: true,
get: () =>
this.allFieldsInTable.filter((f) => {
const isNotThisField = f.id !== this.defaultValues.id
return isNotThisField
}),
})
return context
},
dataProviders() {
return [this.$registry.get('databaseDataProvider', 'fields')]
},
aITypes() {
return Object.keys(this.settings.generative_ai)
},
aIModelsPerType() {
return (
this.settings.generative_ai[this.values.ai_generative_ai_type] || []
)
},
},
validations: {
values: {
ai_generative_ai_type: { required },
ai_generative_ai_model: { required },
ai_prompt: { required },
},
},
}
</script>

View file

@ -36,6 +36,7 @@
:table="table"
:view="view"
:field="field"
:all-fields-in-table="allFieldsInTable"
@update="$emit('update', $event)"
@updated="$refs.context.hide()"
></UpdateFieldContext>
@ -94,6 +95,10 @@ export default {
type: Object,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {

View file

@ -1,7 +1,7 @@
<template>
<div>
<FieldSelectThroughFieldSubForm
:fields="fields"
:fields="allFieldsInTable"
:database="database"
:default-values="defaultValues"
></FieldSelectThroughFieldSubForm>
@ -27,12 +27,6 @@ export default {
database() {
return this.$store.getters['application/get'](this.table.database_id)
},
fields() {
// This part might fail in the future because we can't 100% depend on that the
// fields in the store are related to the component that renders this. An example
// is if you edit the field type in a row edit modal of a related table.
return this.$store.getters['field/getAll']
},
},
validations: {},
}

View file

@ -60,7 +60,10 @@
:icon="fieldType.iconClass"
:name="fieldType.getName()"
:value="fieldType.type"
:disabled="primary && !fieldType.canBePrimaryField"
:disabled="
(primary && !fieldType.canBePrimaryField) ||
!fieldType.isEnabled()
"
></DropdownItem>
</Dropdown>
<div v-if="$v.values.type.$error" class="error">
@ -76,6 +79,7 @@
:field-type="values.type"
:view="view"
:primary="primary"
:all-fields-in-table="allFieldsInTable"
:name="values.name"
:default-values="defaultValues"
@validate="$v.$touch"
@ -121,6 +125,10 @@ export default {
required: false,
default: null,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {
@ -201,7 +209,10 @@ export default {
return !RESERVED_BASEROW_FIELD_NAMES.includes(param?.trim())
},
getFormComponent(type) {
return this.$registry.get('field', type).getFormComponent()
const fieldType = this.$registry.get('field', type)
if (fieldType.isEnabled()) {
return fieldType.getFormComponent()
}
},
showFieldTypesDropdown(target) {
this.$refs.fieldTypesDropdown.show(target)

View file

@ -9,6 +9,7 @@
:view="view"
:loading="refreshingFormula"
:formula-type-refresh-needed="formulaTypeRefreshNeeded"
:all-fields-in-table="allFieldsInTable"
@open-advanced-context="
$refs.advancedFormulaEditContext.openContext($event)
"
@ -30,7 +31,6 @@
<script>
import { required } from 'vuelidate/lib/validators'
import { mapGetters } from 'vuex'
import form from '@baserow/modules/core/mixins/form'
import { notifyIf } from '@baserow/modules/core/utils/error'
@ -69,9 +69,6 @@ export default {
}
},
computed: {
...mapGetters({
rawFields: 'field/getAll',
}),
localOrServerFormulaType() {
return (
this.mergedTypeOptions.array_formula_type ||
@ -79,7 +76,7 @@ export default {
)
},
fieldsUsableInFormula() {
return this.rawFields.filter((f) => {
return this.allFieldsInTable.filter((f) => {
const isNotThisField = f.id !== this.defaultValues.id
const canBeReferencedByFormulaField = this.$registry
.get('field', f.type)

View file

@ -1,7 +1,7 @@
<template>
<div>
<FieldSelectThroughFieldSubForm
:fields="fields"
:fields="allFieldsInTable"
:database="database"
:default-values="defaultValues"
@input="selectedThroughField = $event"
@ -20,6 +20,7 @@
:formula-type="targetFieldFormulaType"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
>
</FormulaTypeSubForms>
</template>
@ -27,8 +28,6 @@
</template>
<script>
import { mapGetters } from 'vuex'
import form from '@baserow/modules/core/mixins/form'
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
import FormulaTypeSubForms from '@baserow/modules/database/components/formula/FormulaTypeSubForms'
@ -66,12 +65,6 @@ export default {
}
return 'unknown'
},
...mapGetters({
// This part might fail in the future because we can't 100% depend on that the
// fields in the store are related to the component that renders this. An example
// is if you edit the field type in a row edit modal of a related table.
fields: 'field/getAll',
}),
},
validations: {},
methods: {

View file

@ -1,7 +1,7 @@
<template>
<div>
<FieldSelectThroughFieldSubForm
:fields="fields"
:fields="allFieldsInTable"
:database="database"
:default-values="defaultValues"
@input="selectedThroughField = $event"
@ -43,6 +43,7 @@
:formula-type="targetFieldFormulaType"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
>
</FormulaTypeSubForms>
</template>
@ -53,7 +54,6 @@
</template>
<script>
import { mapGetters } from 'vuex'
import { required } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form'
@ -99,12 +99,6 @@ export default {
(f) => f.isRollupCompatible()
)
},
...mapGetters({
// This part might fail in the future because we can't 100% depend on that the
// fields in the store are related to the component that renders this. An example
// is if you edit the field type in a row edit modal of a related table.
fields: 'field/getAll',
}),
},
validations: {
values: {

View file

@ -5,6 +5,7 @@
:view="view"
:force-typed="forcedType"
:use-action-group-id="true"
:all-fields-in-table="allFieldsInTable"
@field-created="$emit('field-created', $event)"
@field-created-callback-done="updateInsertedFieldOrder"
></CreateFieldContext>
@ -35,6 +36,10 @@ export default {
type: Object,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {

View file

@ -11,6 +11,7 @@
:view="view"
:default-values="field"
:primary="field.primary"
:all-fields-in-table="allFieldsInTable"
@submitted="submit"
>
<div
@ -23,7 +24,7 @@
type="submit"
class="button"
:class="{ 'button--loading': loading }"
:disabled="loading"
:disabled="loading || fieldTypeDisabled"
>
{{ $t('action.save') }}
</button>
@ -54,12 +55,21 @@ export default {
type: Object,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {
loading: false,
}
},
computed: {
fieldTypeDisabled() {
return !this.$registry.get('field', this.field.type).isEnabled()
},
},
watch: {
field() {
// If the field values are updated via an outside source, think of real time

View file

@ -40,6 +40,7 @@
:formula-type="formulaType"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
>
</FormulaTypeSubForms>
</template>

View file

@ -6,6 +6,7 @@
:table="table"
:view="view"
:allow-set-number-negative="false"
:all-fields-in-table="allFieldsInTable"
>
</FieldNumberSubForm>
<FieldDateSubForm
@ -13,6 +14,7 @@
:default-values="defaultValues"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
>
</FieldDateSubForm>
<FieldDurationSubForm
@ -20,6 +22,7 @@
:default-values="defaultValues"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
>
</FieldDurationSubForm>
</div>

View file

@ -5,7 +5,7 @@
:database="database"
:table="table"
:rows="[]"
:fields="fields"
:all-fields-in-table="fields"
:visible-fields="fields"
:fields-sortable="fieldsSortable"
:can-modify-fields="canModifyFields"

View file

@ -18,6 +18,7 @@
:table="table"
:database="database"
:can-modify-fields="canModifyFields"
:all-fields-in-table="allFieldsInTable"
@field-updated="$emit('field-updated', $event)"
@field-deleted="$emit('field-deleted')"
@order-fields="$emit('order-fields', $event)"
@ -42,6 +43,7 @@
:view="view"
:database="database"
:can-modify-fields="canModifyFields"
:all-fields-in-table="allFieldsInTable"
@field-updated="$emit('field-updated', $event)"
@field-deleted="$emit('field-deleted')"
@toggle-field-visibility="$emit('toggle-field-visibility', $event)"
@ -122,6 +124,10 @@ export default {
required: false,
default: false,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {

View file

@ -0,0 +1,28 @@
<template>
<div class="control__elements">
<textarea
ref="input"
v-model="value"
type="text"
class="input field-long-text margin-bottom-2"
:disabled="true"
/>
<a
v-if="rowIsCreated"
class="button button--ghost"
:class="{ 'button--loading': generating }"
@click="generate()"
>{{ $t('rowEditFieldAI.generate') }}</a
>
<div v-else>{{ $t('rowEditFieldAI.createRowBefore') }}</div>
</div>
</template>
<script>
import rowEditField from '@baserow/modules/database/mixins/rowEditField'
import fieldAI from '@baserow/modules/database/mixins/fieldAI'
export default {
mixins: [rowEditField, fieldAI],
}
</script>

View file

@ -53,6 +53,7 @@
:view="view"
:table="table"
:database="database"
:all-fields-in-table="allFieldsInTable"
@field-updated="$emit('field-updated', $event)"
@field-deleted="$emit('field-deleted')"
@order-fields="$emit('order-fields', $event)"
@ -77,6 +78,7 @@
:view="view"
:table="table"
:database="database"
:all-fields-in-table="allFieldsInTable"
@field-updated="$emit('field-updated', $event)"
@field-deleted="$emit('field-deleted')"
@toggle-field-visibility="$emit('toggle-field-visibility', $event)"
@ -109,6 +111,7 @@
ref="createFieldContext"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
@field-created="$emit('field-created', $event)"
></CreateFieldContext>
</div>
@ -118,7 +121,7 @@
:row="row"
:table="table"
:database="database"
:fields="fields"
:fields="allFieldsInTable"
:read-only="readOnly"
></RowEditModalSidebar>
</template>
@ -157,7 +160,7 @@ export default {
required: false,
default: null,
},
fields: {
allFieldsInTable: {
type: Array,
required: true,
},

View file

@ -20,6 +20,7 @@
:table="table"
:view="view"
:field="field"
:all-fields-in-table="allFieldsInTable"
@update="$emit('field-updated', $event)"
@delete="$emit('field-deleted')"
>
@ -43,6 +44,7 @@
:field="field"
:value="row['field_' + field.id]"
:read-only="readOnly"
:row-is-created="!!row.id"
@update="update"
@refresh-row="$emit('refresh-row', row)"
/>
@ -100,6 +102,15 @@ export default {
type: Boolean,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
rowIsCreated: {
type: Boolean,
required: false,
default: () => true,
},
},
methods: {
getFieldComponent(type) {

View file

@ -26,6 +26,7 @@
:database="database"
:sortable="sortable && fieldIsSortable(field)"
:can-modify-fields="canModifyFields"
:all-fields-in-table="allFieldsInTable"
@field-updated="$emit('field-updated', $event)"
@field-deleted="$emit('field-deleted')"
@update="$emit('update', $event)"
@ -88,6 +89,10 @@ export default {
type: Object,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
methods: {
fieldIsSortable(field) {

View file

@ -44,6 +44,7 @@
:database="database"
:table="table"
:sortable="false"
:all-fields-in-table="allFields"
:visible-fields="allFields"
:can-modify-fields="false"
@created="createRow"

View file

@ -6,6 +6,7 @@
:view="view"
:fields="disabledFields"
:enabled-fields="enabledFields"
:all-fields-in-table="fields"
:read-only="
readOnly ||
!$hasPermission(

View file

@ -68,6 +68,7 @@
ref="createFieldContext"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
@field-created="$event.callback()"
></CreateFieldContext>
</div>
@ -123,6 +124,10 @@ export default {
type: Boolean,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
computed: {
modeType() {

View file

@ -86,6 +86,7 @@
:visible-fields="cardFields"
:hidden-fields="hiddenFields"
:show-hidden-fields="showHiddenFieldsInRowModal"
:all-fields-in-table="fields"
@toggle-hidden-fields-visibility="
showHiddenFieldsInRowModal = !showHiddenFieldsInRowModal
"
@ -101,7 +102,7 @@
:database="database"
:table="table"
:view="view"
:fields="fields"
:all-fields-in-table="fields"
:primary-is-sortable="true"
:visible-fields="cardFields"
:hidden-fields="hiddenFields"

View file

@ -156,6 +156,15 @@
:max-height-if-outside-viewport="true"
>
<ul v-show="isMultiSelectActive" class="context__menu">
<component
:is="contextItemComponent"
v-for="(contextItemComponent, index) in getMultiSelectContextItems()"
:key="index"
:field="getSelectedField()"
:rows="getSelectedRows()"
:store-prefix="storePrefix"
@click=";[$refs.rowContext.hide()]"
></component>
<li class="context__menu-item">
<a
class="context__menu-item-link"
@ -297,7 +306,7 @@
:database="database"
:table="table"
:view="view"
:fields="fields"
:all-fields-in-table="fields"
:visible-fields="allVisibleFields"
:hidden-fields="hiddenFields"
:rows="allRows"
@ -1177,6 +1186,7 @@ export default {
rowId: row.id,
fieldIndex,
})
this.$refs.rowContext.hide()
},
/**
* Called when mouse hovers over a GridViewCell component.
@ -1425,7 +1435,6 @@ export default {
}
this.$store.dispatch('toast/setPasting', true)
try {
await this.$store.dispatch(
this.storePrefix + 'view/grid/updateDataIntoCells',
@ -1531,6 +1540,38 @@ export default {
height
)
},
/**
* Called when the user right clicks after selecting multiple cells.
* Shows the context menu with the appropriate options.
*/
getMultiSelectContextItems() {
const selectedFields = this.$store.getters[
this.storePrefix + 'view/grid/getSelectedFields'
](this.fields)
if (selectedFields.length === 1) {
return this.$registry
.get('field', selectedFields[0].type)
.getGridViewContextItemsOnCellsSelection()
} else {
return []
}
},
/**
* Returns the selected field if only one field is selected, otherwise returns null.
*/
getSelectedField() {
const selectedFields = this.$store.getters[
this.storePrefix + 'view/grid/getSelectedFields'
](this.fields)
return selectedFields.length === 1 ? selectedFields[0] : null
},
/**
* Returns the selected rows if any rows are selected, otherwise returns an empty array.
*/
getSelectedRows() {
return this.$store.getters[this.storePrefix + 'view/grid/getSelectedRows']
},
},
}
</script>

View file

@ -45,6 +45,7 @@
:table="table"
:view="view"
:field="field"
:all-fields-in-table="allFieldsInTable"
@update="$emit('refresh', $event)"
@delete="$emit('refresh')"
>
@ -98,6 +99,7 @@
:table="table"
:view="view"
:from-field="field"
:all-fields-in-table="allFieldsInTable"
@field-created="$emit('field-created', $event)"
@move-field="moveField($event)"
></InsertFieldContext>
@ -124,6 +126,7 @@
ref="duplicateFieldModal"
:table="table"
:from-field="field"
:all-fields-in-table="allFieldsInTable"
@field-created="$emit('field-created', $event)"
@move-field="moveField($event)"
></DuplicateFieldModal>
@ -288,6 +291,10 @@ export default {
type: Boolean,
required: true,
},
allFieldsInTable: {
type: Array,
required: true,
},
},
data() {
return {

View file

@ -38,6 +38,7 @@
:table="table"
:view="view"
:field="field"
:all-fields-in-table="allFieldsInTable"
:filters="view.filters"
:include-field-width-handles="includeFieldWidthHandles"
:read-only="readOnly"
@ -71,6 +72,7 @@
ref="createFieldContext"
:table="table"
:view="view"
:all-fields-in-table="allFieldsInTable"
@field-created="$emit('field-created', $event)"
@shown="onShownCreateFieldContext"
></CreateFieldContext>

View file

@ -0,0 +1,30 @@
<template functional>
<div
v-if="!props.value || $options.methods.isGenerating(parent, props)"
class="grid-view__cell"
>
<div class="grid-field-button">
<a
class="button button--tiny button--ghost"
:disabled="!$options.methods.isModelAvailable(parent, props)"
:class="{
'button--loading': $options.methods.isGenerating(parent, props),
}"
>
<i18n path="functionalGridViewFieldAI.generate" tag="span" />
</a>
</div>
</div>
<div v-else class="grid-view__cell grid-field-long-text__cell">
<div class="grid-field-long-text">{{ props.value }}</div>
</div>
</template>
<script>
import gridFieldAI from '@baserow/modules/database/mixins/gridFieldAI'
export default {
name: 'FunctionalGridViewFieldAI',
mixins: [gridFieldAI],
}
</script>

View file

@ -0,0 +1,60 @@
<template>
<div
v-if="!value || (!opened && generating)"
ref="cell"
class="grid-view__cell active"
>
<div class="grid-field-button">
<button
class="button button--tiny button--ghost"
:disabled="!modelAvailable"
:class="{ 'button--loading': generating }"
@click="generate()"
>
{{ $t('gridViewFieldAI.generate') }}
</button>
</div>
</div>
<div
v-else
ref="cell"
class="grid-view__cell grid-field-long-text__cell active"
:class="{ editing: opened }"
@keyup.enter="opened = true"
>
<div v-if="!opened" class="grid-field-long-text">{{ value }}</div>
<template v-else>
<div class="grid-field-long-text__textarea">
{{ value }}
</div>
<div style="background-color: #fff; padding: 8px">
<button
class="button button--link"
:disabled="!modelAvailable"
:class="{ 'button--loading': generating }"
@click.prevent.stop="generate()"
>
<i class="button__icon iconoir-magic-wand"></i>
{{ $t('gridViewFieldAI.regenerate') }}
</button>
</div>
</template>
</div>
</template>
<script>
import gridField from '@baserow/modules/database/mixins/gridField'
import gridFieldInput from '@baserow/modules/database/mixins/gridFieldInput'
import gridFieldAI from '@baserow/modules/database/mixins/gridFieldAI'
export default {
mixins: [gridField, gridFieldInput, gridFieldAI],
methods: {
save() {
this.opened = false
this.editing = false
this.afterSave()
},
},
}
</script>

View file

@ -0,0 +1,71 @@
<template>
<li class="context__menu-item">
<a
class="context__menu-item-link"
:class="{ disabled: !modelAvailable }"
@click.prevent.stop=";[generateAIFieldValues()]"
>
<i class="context__menu-item-icon iconoir-magic-wand"></i>
{{ $t('gridView.generateCellsValues') }}
</a>
</li>
</template>
<script>
import FieldService from '@baserow/modules/database/services/field'
import { notifyIf } from '@baserow/modules/core/utils/error'
export default {
props: {
field: {
type: Object,
required: true,
},
rows: {
type: Array,
required: true,
},
storePrefix: {
type: String,
required: true,
},
},
computed: {
modelAvailable() {
const aIModels =
this.$store.getters['settings/get'].generative_ai[
this.field.ai_generative_ai_type
] || []
return (
this.$registry.get('field', this.field.type).isEnabled() &&
aIModels.includes(this.field.ai_generative_ai_model)
)
},
},
methods: {
async generateAIFieldValues($event) {
if (!this.modelAvailable) {
return
}
const rowIds = this.rows.map((row) => row.id)
const fieldId = this.field.id
this.$store.dispatch(
this.storePrefix + 'view/grid/setPendingFieldOperations',
{ fieldId, rowIds, value: true }
)
try {
await FieldService(this.$client).generateAIFieldValues(fieldId, rowIds)
} catch (error) {
this.$store.dispatch(
this.storePrefix + 'view/grid/setPendingFieldOperations',
{ fieldId, rowIds, value: false }
)
notifyIf(error, 'field')
}
this.$emit('click', $event)
},
},
}
</script>

View file

@ -0,0 +1,30 @@
import { DataProviderType } from '@baserow/modules/core/dataProviderTypes'
export class FieldsDataProviderType extends DataProviderType {
static getType() {
return 'fields'
}
get name() {
return this.app.i18n.t('dataProviderTypes.fieldsName')
}
getDataContent(applicationContext) {
return ''
}
getDataSchema(applicationContext) {
return {
type: 'object',
properties: Object.fromEntries(
(applicationContext.fields || []).map((field) => [
`field_${field.id}`,
{
title: field.name,
type: 'string',
},
])
),
}
}
}

View file

@ -30,6 +30,7 @@ import FieldLinkRowSubForm from '@baserow/modules/database/components/field/Fiel
import FieldSelectOptionsSubForm from '@baserow/modules/database/components/field/FieldSelectOptionsSubForm'
import FieldCollaboratorSubForm from '@baserow/modules/database/components/field/FieldCollaboratorSubForm'
import FieldPasswordSubForm from '@baserow/modules/database/components/field/FieldPasswordSubForm'
import FieldAISubForm from '@baserow/modules/database/components/field/FieldAISubForm'
import GridViewFieldText from '@baserow/modules/database/components/view/grid/fields/GridViewFieldText'
import GridViewFieldLongText from '@baserow/modules/database/components/view/grid/fields/GridViewFieldLongText'
@ -52,6 +53,7 @@ import GridViewFieldUUID from '@baserow/modules/database/components/view/grid/fi
import GridViewFieldAutonumber from '@baserow/modules/database/components/view/grid/fields/GridViewFieldAutonumber'
import GridViewFieldLastModifiedBy from '@baserow/modules/database/components/view/grid/fields/GridViewFieldLastModifiedBy'
import GridViewFieldPassword from '@baserow/modules/database/components/view/grid/fields/GridViewFieldPassword'
import GridViewFieldAI from '@baserow/modules/database/components/view/grid/fields/GridViewFieldAI'
import FunctionalGridViewFieldText from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldText'
import FunctionalGridViewFieldDuration from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldDuration'
@ -72,6 +74,7 @@ import FunctionalGridViewFieldUUID from '@baserow/modules/database/components/vi
import FunctionalGridViewFieldAutonumber from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldAutonumber'
import FunctionalGridViewFieldLastModifiedBy from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldLastModifiedBy'
import FunctionalGridVIewFieldPassword from '@baserow/modules/database/components/view/grid/fields/FunctionalGridVIewFieldPassword.vue'
import FunctionalGridViewFieldAI from '@baserow/modules/database/components/view/grid/fields/FunctionalGridViewFieldAI'
import RowEditFieldText from '@baserow/modules/database/components/row/RowEditFieldText'
import RowEditFieldLongText from '@baserow/modules/database/components/row/RowEditFieldLongText'
@ -94,6 +97,7 @@ import RowEditFieldUUID from '@baserow/modules/database/components/row/RowEditFi
import RowEditFieldAutonumber from '@baserow/modules/database/components/row/RowEditFieldAutonumber'
import RowEditFieldLastModifiedBy from '@baserow/modules/database/components/row/RowEditFieldLastModifiedBy'
import RowEditFieldPassword from '@baserow/modules/database/components/row/RowEditFieldPassword'
import RowEditFieldAI from '@baserow/modules/database/components/row/RowEditFieldAI'
import RowCardFieldBoolean from '@baserow/modules/database/components/card/RowCardFieldBoolean'
import RowCardFieldDate from '@baserow/modules/database/components/card/RowCardFieldDate'
@ -134,6 +138,8 @@ import FormViewFieldMultipleLinkRow from '@baserow/modules/database/components/v
import FormViewFieldMultipleSelectCheckboxes from '@baserow/modules/database/components/view/form/FormViewFieldMultipleSelectCheckboxes'
import FormViewFieldSingleSelectRadios from '@baserow/modules/database/components/view/form/FormViewFieldSingleSelectRadios'
import GridViewFieldAIGenerateValuesContextItem from '@baserow/modules/database/components/view/grid/fields/GridViewFieldAIGenerateValuesContextItem'
import { trueValues } from '@baserow/modules/core/utils/constants'
import {
getDateMomentFormat,
@ -197,6 +203,16 @@ export class FieldType extends Registerable {
)
}
/**
* This method generates the context menu options for actions that can be performed on
* more selected cells within the same field. These options appear in the grid view
* when the user right-clicks on multiple cells.
* @param field The field object.
*/
getGridViewContextItemsOnCellsSelection(field) {
return []
}
/**
* This functional component should represent an unselect field cell related to the
* value of this type. It will only be used in the grid view and is only for fast
@ -756,6 +772,13 @@ export class FieldType extends Registerable {
parseInputValue(field, value) {
return value
}
/**
* Indicates whether it's possible to select the field type when creating or updating the field.
*/
isEnabled() {
return true
}
}
export class TextFieldType extends FieldType {
@ -4137,3 +4160,88 @@ export class PasswordFieldType extends FieldType {
return RowHistoryFieldPassword
}
}
export class AIFieldType extends FieldType {
static getType() {
return 'ai'
}
getIconClass() {
return 'iconoir-magic-wand'
}
getName() {
const { i18n } = this.app
return i18n.t('fieldType.ai')
}
getGridViewFieldComponent() {
return GridViewFieldAI
}
getFunctionalGridViewFieldComponent() {
return FunctionalGridViewFieldAI
}
getRowEditFieldComponent(field) {
return RowEditFieldAI
}
getCardComponent() {
return RowCardFieldText
}
getRowHistoryEntryComponent() {
return RowHistoryFieldText
}
getFormComponent() {
return FieldAISubForm
}
getFormViewFieldComponents(field) {
return {}
}
getEmptyValue(field) {
return null
}
getSort(name, order) {
return (a, b) => {
const stringA = a[name] === null ? '' : '' + a[name]
const stringB = b[name] === null ? '' : '' + b[name]
return collatedStringCompare(stringA, stringB, order)
}
}
getDocsDataType(field) {
return 'string'
}
getDocsDescription(field) {
return '@TODO'
}
getDocsRequestExample(field) {
return 'string'
}
getContainsFilterFunction() {
return genericContainsFilter
}
getContainsWordFilterFunction(field) {
return genericContainsWordFilter
}
getGridViewContextItemsOnCellsSelection(field) {
return [GridViewFieldAIGenerateValuesContextItem]
}
isEnabled() {
const { store } = this.app
return Object.keys(store.getters['settings/get'].generative_ai).length > 0
}
}

View file

@ -599,11 +599,14 @@
"deleteRow": "Delete row",
"deleteRows": "Delete rows",
"copyCells": "Copy cells",
"generateCellsValues": "Generate values with AI",
"rowCount": "No rows | 1 row | {count} rows",
"hiddenRowsInsertedTitle": "Rows added",
"hiddenRowsInsertedMessage": "{number} newly added rows have been added, but are not visible because of the active filters.",
"tooManyItemsTitle": "Too many items",
"tooManyItemsDescription": "It's not possible to update more than {limit} rows at once, so we've only updated the first."
"tooManyItemsDescription": "It's not possible to update more than {limit} rows at once, so we've only updated the first.",
"AIValuesGenerationErrorTitle": "AI value generation failed",
"AIValuesGenerationErrorMessage": "Please check your API_KEY and verify the selected model."
},
"gridViewFieldFile": {
"dropHere": "Drop here",
@ -829,5 +832,27 @@
"passwordSet": "The password was set",
"passwordUpdated": "The password was updated",
"passwordDeleted": "The password was deleted"
},
"dataProviderTypes": {
"fieldsName": "Fields"
},
"functionalGridViewFieldAI": {
"generate": "Generate"
},
"gridViewFieldAI": {
"generate": "Generate",
"regenerate": "Re-Generate"
},
"fieldAISubForm": {
"AIType": "AI Type",
"AIModel": "AI Model",
"prompt": "Prompt"
},
"rowEditFieldAI": {
"generate": "Generate",
"createRowBefore": "The AI value can be generated after the row has been created."
},
"rowCardFieldAI": {
"generate": "Generate"
}
}

View file

@ -0,0 +1,40 @@
import FieldService from '@baserow/modules/database/services/field'
import { notifyIf } from '@baserow/modules/core/utils/error'
export default {
data() {
return {
generating: false,
}
},
computed: {
modelAvailable() {
const aIModels =
this.$store.getters['settings/get'].generative_ai[
this.field.ai_generative_ai_type
] || []
return (
this.$registry.get('field', this.field.type).isEnabled() &&
aIModels.includes(this.field.ai_generative_ai_model)
)
},
},
watch: {
value() {
this.generating = false
},
},
methods: {
async generate() {
this.generating = true
try {
await FieldService(this.$client).generateAIFieldValues(this.field.id, [
this.$parent.row.id,
])
} catch (error) {
notifyIf(error, 'field')
this.generating = false
}
},
},
}

View file

@ -15,6 +15,9 @@ export default {
},
primary: {
type: Boolean,
},
allFieldsInTable: {
type: Array,
required: true,
},
},

View file

@ -0,0 +1,50 @@
import FieldService from '@baserow/modules/database/services/field'
import { notifyIf } from '@baserow/modules/core/utils/error'
export default {
computed: {
// Indicates if the cell is currently being generated together with other cells in
// bulk via the apposite grid-view menu.
generating() {
return this.isGenerating(this.$parent, this.$props)
},
modelAvailable() {
return this.isModelAvailable(this.$parent, this.$props)
},
},
methods: {
isGenerating(parent, props) {
return parent.row._.pendingFieldOps?.find(
(fieldName) => fieldName === `field_${props.field.id}`
)
},
isModelAvailable(parent, props) {
const aIModels =
parent.$store.getters['settings/get'].generative_ai[
props.field.ai_generative_ai_type
] || []
return (
parent.$registry.get('field', props.field.type).isEnabled() &&
aIModels.includes(props.field.ai_generative_ai_model)
)
},
async generate() {
const rowId = this.$parent.row.id
this.$store.dispatch(
this.storePrefix + 'view/grid/setPendingFieldOperations',
{ fieldId: this.field.id, rowIds: [rowId], value: true }
)
try {
await FieldService(this.$client).generateAIFieldValues(this.field.id, [
rowId,
])
} catch (error) {
notifyIf(error, 'field')
this.$store.dispatch(
this.storePrefix + 'view/grid/setPendingFieldOperations',
{ fieldId: this.field.id, rowIds: [rowId], value: false }
)
}
},
},
}

View file

@ -32,6 +32,11 @@ export default {
required: false,
default: true,
},
rowIsCreated: {
type: Boolean,
required: false,
default: () => true,
},
},
methods: {
/**

View file

@ -32,6 +32,7 @@ import {
UUIDFieldType,
AutonumberFieldType,
PasswordFieldType,
AIFieldType,
} from '@baserow/modules/database/fieldTypes'
import {
EqualViewFilterType,
@ -257,6 +258,7 @@ import {
FormSubmittedNotificationType,
} from '@baserow/modules/database/notificationTypes'
import { HistoryRowModalSidebarType } from '@baserow/modules/database/rowModalSidebarTypes'
import { FieldsDataProviderType } from '@baserow/modules/database/dataProviderTypes'
import en from '@baserow/modules/database/locales/en.json'
import fr from '@baserow/modules/database/locales/fr.json'
@ -465,6 +467,7 @@ export default (context) => {
app.$registry.register('field', new UUIDFieldType(context))
app.$registry.register('field', new AutonumberFieldType(context))
app.$registry.register('field', new PasswordFieldType(context))
app.$registry.register('field', new AIFieldType(context))
app.$registry.register('importer', new CSVImporterType(context))
app.$registry.register('importer', new PasteImporterType(context))
@ -709,6 +712,11 @@ export default (context) => {
app.$registry.register('formViewMode', new FormViewFormModeType(context))
app.$registry.register(
'databaseDataProvider',
new FieldsDataProviderType(context)
)
// notifications
app.$registry.register(
'notification',

View file

@ -214,6 +214,24 @@ export const registerRealtimeEvents = (realtime) => {
}
})
realtime.registerEvent(
'rows_ai_values_generation_error',
async (context, data) => {
const { app } = context
for (const viewType of Object.values(app.$registry.getAll('view'))) {
await viewType.AIValuesGenerationError(
context,
data.table_id,
data.field_id,
data.row_ids,
data.error,
'page/'
)
}
}
)
realtime.registerEvent('rows_deleted', (context, data) => {
const { app, store } = context
for (const viewType of Object.values(app.$registry.getAll('view'))) {

View file

@ -46,5 +46,11 @@ export default (client) => {
config
)
},
generateAIFieldValues(fieldId, rowIds) {
return client.post(
`/database/fields/${fieldId}/generate-ai-field-values/`,
{ row_ids: rowIds }
)
},
}
}

View file

@ -49,6 +49,10 @@ export function populateRow(row, metadata = {}) {
// between cells.
selected: false,
selectedFieldId: -1,
// Contains the specific field ids that are in a loading state. This is for
// example used for fields that use a background worker to compute the value
// like the AI field.
pendingFieldOps: [],
}
return row
}
@ -412,6 +416,17 @@ export const mutations = {
if (metadata) {
existingRowState._.metadata = metadata
}
// Remove every pending AI field if a value is provided for it.
if (existingRowState._?.pendingFieldOps?.length > 0) {
const newFieldKeys = new Set(
Object.keys(values).filter((key) => values[key])
)
existingRowState._.pendingFieldOps =
existingRowState._.pendingFieldOps.filter(
(key) => !newFieldKeys.has(key)
)
}
}
},
UPDATE_ROW_VALUES(state, { row, values }) {
@ -586,6 +601,20 @@ export const mutations = {
}
})
},
SET_PENDING_FIELD_OPERATIONS(state, { fieldId, rowIds, value }) {
const key = `field_${fieldId}`
state.rows.forEach((row) => {
if (rowIds.includes(row.id)) {
if (value) {
row._.pendingFieldOps.push(key)
} else {
row._.pendingFieldOps = row._.pendingFieldOps.filter(
(fieldName) => fieldName !== key
)
}
}
})
},
}
// Contains the info needed for the delayed scroll top action.
@ -2260,6 +2289,7 @@ export const actions = {
jsonData,
rowIndex,
fieldIndex,
selectUpdatedCells = true,
}
) {
const copiedRowsCount = textData.length
@ -2324,7 +2354,7 @@ export const actions = {
rowTailIndex = rowTailIndex + newRowsCount
}
if (!isSingleCellCopied) {
if (!isSingleCellCopied && selectUpdatedCells) {
// Expand the selection of the multiple select to the cells that we're going to
// paste in, so the user can see which values have been updated. This is because
// it could be that there are more or less values in the clipboard compared to
@ -2929,6 +2959,26 @@ export const actions = {
fieldIndex: minFieldIndex,
})
},
/**
* Add the fieldId to the list of pending field operations for the given rowIds.
* This is used to show a loading spinner when a field is being updated. For example,
* the AI field type uses this to show a spinner when the AI values are being
* generated in a background task.
*/
setPendingFieldOperations({ commit }, { fieldId, rowIds, value = true }) {
commit('SET_PENDING_FIELD_OPERATIONS', { fieldId, rowIds, value })
},
AIValuesGenerationError({ commit, dispatch }, { fieldId, rowIds }) {
commit('SET_PENDING_FIELD_OPERATIONS', { fieldId, rowIds, value: false })
dispatch(
'toast/error',
{
title: this.$i18n.t('gridView.AIValuesGenerationErrorTitle'),
message: this.$i18n.t('gridView.AIValuesGenerationErrorMessage'),
},
{ root: true }
)
},
}
export const getters = {
@ -3144,6 +3194,23 @@ export const getters = {
)
}
},
getSelectedFields: (state, getters) => (fields) => {
const [minField, maxField] = getters.getMultiSelectFieldIndexSorted
const selectedFields = []
const fieldMap = fields.reduce((acc, field) => {
acc[field.id] = field
return acc
}, {})
for (let i = minField; i <= maxField; i++) {
const fieldId = getters.getFieldIdByIndex(i, fields)
if (fieldId !== -1) {
selectedFields.push(fieldMap[fieldId])
}
}
return selectedFields
},
getAllFieldAggregationData(state) {
return state.fieldAggregationData
},

View file

@ -259,6 +259,12 @@ export class ViewType extends Registerable {
*/
rowUpdated(context, tableId, fields, row, values, metadata, storePrefix) {}
/**
* Event that is called when something went wrong while generating AI values
* for a field. This can be used to show an error message to the user.
*/
AIValuesGenerationError(context, tableId, fieldId, rowIds, error) {}
/**
* Event that is called when a row is deleted from an outside source, so for example
* via a real time event by another user. It can be used to check if data in an store
@ -630,6 +636,27 @@ export class GridViewType extends ViewType {
})
}
}
AIValuesGenerationError(
context,
tableId,
fieldId,
rowIds,
error,
storePrefix = ''
) {
if (this.isCurrentView(context.store, tableId)) {
context.store.dispatch(
storePrefix + 'view/grid/AIValuesGenerationError',
{
fieldId,
rowIds,
error,
},
{ root: true }
)
}
}
}
/**

View file

@ -333,6 +333,15 @@ const mockedFields = {
type: 'password',
testingRowData: [null, true, 'test'],
},
ai: {
id: 26,
name: 'ai',
order: 26,
primary: false,
table_id: 42,
type: 'ai',
testingRowData: [null, 'Generated: hello!'],
},
}
const valuesToCall = [null, undefined]