mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-12 16:28:06 +00:00
451 lines
14 KiB
Python
451 lines
14 KiB
Python
from contextlib import contextmanager
|
|
from typing import Callable, Dict, Optional, Tuple, Type, Union
|
|
|
|
from django.utils.encoding import force_str
|
|
|
|
from rest_framework import serializers, status
|
|
from rest_framework.exceptions import APIException
|
|
from rest_framework.request import Request
|
|
from rest_framework.serializers import ModelSerializer
|
|
|
|
from baserow.core.exceptions import InstanceTypeDoesNotExist
|
|
|
|
from .exceptions import RequestBodyValidationException
|
|
|
|
# A fully specified error: (error code, HTTP status code, detail message).
ErrorTupleType = Tuple[str, int, str]

# A statically declared error: either only the error code or the full error tuple.
_StaticErrorType = Union[str, ErrorTupleType]

# Maps an exception class to the error that must be raised in its place. A value is
# either a static error, or a callable that receives the raised exception and returns
# a static error or None.
ExceptionMappingType = Dict[
    Type[Exception],
    Union[_StaticErrorType, Callable[[Exception], Optional[_StaticErrorType]]],
]
|
|
|
|
|
|
@contextmanager
def map_exceptions(mapping: ExceptionMappingType):
    """
    This utility function simplifies mapping uncaught exceptions to a standard api
    response exception.

    The value that belongs to an exception class can be:

    - a string, which is used as the error code.
    - a tuple in the form of `(error, status_code, detail)`. The status code and the
      detail are optional and fall back to 400 and an empty string. The detail is
      formatted with the caught exception available as `e`, for example
      `'Failed because of {e}'`.
    - a callable that receives the caught exception and returns one of the above, or
      None if the exception must not be mapped.

    The exceptions registered in the `api_exception_registry` are always added to the
    provided mapping and take precedence over an entry for the same exception class.
    The provided mapping itself is never modified.

    Example:
      with map_exceptions({ SomeException: 'ERROR_1' }):
          raise SomeException('This is a test')

      # Raises an APIException with status code 400 and the detail:
      {
        "error": "ERROR_1",
        "detail": ""
      }

    Example 2:
      with map_exceptions({ SomeException: ('ERROR_1', 404, 'Other message') }):
          raise SomeException('This is a test')

      # Raises an APIException with status code 404 and the detail:
      {
        "error": "ERROR_1",
        "detail": "Other message"
      }

    Example 3:
      with map_exceptions(
          {
              SomeException: lambda e: ('ERROR_1', 404, 'Conditional Error')
              if "something" in str(e)
              else None
          }
      ):
          raise SomeException('something')

      # Raises an APIException with status code 404 and the detail:
      {
        "error": "ERROR_1",
        "detail": "Conditional Error"
      }

    Example 4:
      with map_exceptions(
          {
              SomeException: lambda e: ('ERROR_1', 404, 'Conditional Error')
              if "something" in str(e)
              else None
          }
      ):
          raise SomeException('doesnt match')

      # SomeException will be thrown directly if the provided callable returns None.

    :param mapping: The exception classes that must be mapped and the error they
        must be mapped to.
    :type mapping: ExceptionMappingType
    :raises APIException: When an exception that is raised inside the context could
        be mapped to an error.
    """

    from baserow.api.registries import api_exception_registry

    # Work on a copy so that the dict provided by the caller is never mutated. The
    # registered exceptions are applied last, so they override an entry for the same
    # exception class in the provided mapping.
    effective_mapping = dict(mapping)
    for registered in api_exception_registry.get_all():
        effective_mapping[registered.exception_class] = registered.exception_error

    try:
        yield
    except tuple(effective_mapping) as e:
        value = _search_up_class_hierarchy_for_mapping(e, effective_mapping)

        if callable(value):
            value = value(e)

        status_code = status.HTTP_400_BAD_REQUEST
        detail = ""

        if isinstance(value, str):
            error = value
        elif isinstance(value, tuple) and value:
            error = value[0]
            if len(value) > 1 and value[1] is not None:
                status_code = value[1]
            if len(value) > 2 and value[2] is not None:
                detail = value[2].format(e=e)
        else:
            # Either the callable returned None or the mapping doesn't contain a
            # usable error for this exception. Let the original exception propagate
            # unchanged instead of failing on an undefined error code.
            raise

        exc = APIException({"error": error, "detail": detail})
        exc.status_code = status_code

        raise exc
|
|
|
|
|
|
def _search_up_class_hierarchy_for_mapping(e, mapping):
|
|
for clazz in e.__class__.mro():
|
|
value = mapping.get(clazz)
|
|
if value:
|
|
return value
|
|
return None
|
|
|
|
|
|
def validate_data(
    serializer_class,
    data,
    partial=False,
    exception_to_raise=RequestBodyValidationException,
    many=False,
    return_validated=False,
):
    """
    Runs the provided data through the provided serializer class. If the data is not
    valid, the serializer errors are converted to a machine readable structure and
    raised via the provided exception class.

    :param serializer_class: The serializer that must be used for validating.
    :type serializer_class: Serializer
    :param data: The data that needs to be validated.
    :type data: dict
    :param partial: Whether the data is a partial update.
    :type partial: bool
    :param exception_to_raise: The exception class that is instantiated with the
        converted errors and raised when the data is not valid.
    :type exception_to_raise: Type[Exception]
    :param many: Indicates whether the serializer should be constructed as a list.
    :type many: bool
    :param return_validated: If True, the `validated_data` of the serializer is
        returned instead of its `data`.
    :type return_validated: bool
    :return: The data after being validated by the serializer.
    :rtype: dict
    """

    def to_machine_readable(node):
        # Dicts and lists are walked recursively. Every other value is treated as a
        # single error that exposes a `code` attribute.
        if isinstance(node, dict):
            return {name: to_machine_readable(child) for name, child in node.items()}
        if isinstance(node, list):
            return [to_machine_readable(child) for child in node]
        return {"error": force_str(node), "code": node.code}

    serializer = serializer_class(data=data, partial=partial, many=many)

    if serializer.is_valid():
        return serializer.validated_data if return_validated else serializer.data

    raise exception_to_raise(to_machine_readable(serializer.errors))
|
|
|
|
|
|
def validate_data_custom_fields(
    type_name,
    registry,
    data,
    base_serializer_class=None,
    type_attribute_name="type",
    partial=False,
    allow_empty_type=False,
):
    """
    Validates the provided data with the serializer that the type instance, looked up
    in the registry by the provided type_name, generates on top of the provided
    base_serializer_class.

    :param type_name: The type name of the type instance that is needed to generate
        the serializer.
    :type type_name: str
    :param registry: The registry where to get the type instance from.
    :type registry: Registry
    :param data: The data that needs to be validated.
    :type data: dict
    :param base_serializer_class: The base serializer class that is used when
        generating the serializer for validation.
    :type base_serializer_class: ModelSerializer
    :param type_attribute_name: The attribute key name that contains the type value.
    :type type_attribute_name: str
    :param partial: Whether the data is a partial update.
    :type partial: bool
    :param allow_empty_type: If True and no type_name is provided, the data is
        validated with the base_serializer_class itself and the registry is not used.
    :type allow_empty_type: bool
    :raises RequestBodyValidationException: When the type is not a valid choice.
    :return: The validated data.
    :rtype: dict
    """

    if not type_name and allow_empty_type:
        return validate_data(base_serializer_class, data, partial=partial)

    try:
        type_instance = registry.get(type_name)
    except InstanceTypeDoesNotExist:
        # An unknown type name is reported as a machine readable validation error on
        # the type attribute.
        invalid_choice = {
            "error": f'"{type_name}" is not a valid choice.',
            "code": "invalid_choice",
        }
        raise RequestBodyValidationException({type_attribute_name: [invalid_choice]})

    # The request serializer is needed because the data that is being validated
    # comes from a request.
    serializer_class = type_instance.get_serializer_class(
        base_class=base_serializer_class,
        request_serializer=True,
    )
    return validate_data(serializer_class, data, partial=partial)
|
|
|
|
|
|
def get_request(args):
    """
    Extracts the request from the original arguments of a called view method. The
    request is expected to be the second argument.

    :param args: A list containing the original arguments of the called view method.
    :type args: list
    :raises ValueError: When the request has not been found in the args.
    :return: The extracted request object.
    :rtype: Request
    """

    if len(args) >= 2:
        candidate = args[1]
        if isinstance(candidate, Request):
            return candidate

    raise ValueError("There must be a request in the args.")
|
|
|
|
|
|
def type_from_data_or_registry(
    data, registry, model_instance, type_attribute_name="type"
):
    """
    Returns the type that is provided in the data. If the data doesn't contain the
    type, it is resolved via the registry based on the specific class of the
    provided model instance.

    :param data: The data that might contain the type name.
    :type data: dict
    :param registry: The registry where to get the type instance from if not provided in
        the data.
    :type registry: Registry
    :param model_instance: The model instance we want to know the type from if not
        provided in the data.
    :type model_instance: Model
    :param type_attribute_name: The expected type attribute name in the data.
    :type type_attribute_name: str
    :return: The extracted type.
    :rtype: str
    """

    if type_attribute_name not in data:
        type_instance = registry.get_by_model(model_instance.specific_class)
        return type_instance.type

    return data[type_attribute_name]
|
|
|
|
|
|
def get_serializer_class(
    model,
    field_names,
    field_overrides=None,
    base_class=None,
    meta_ref_name=None,
    required_fields=None,
    base_mixins=None,
):
    """
    Dynamically builds a serializer class for the provided model, containing the
    provided field names and field overrides.

    :param model: The model class that must be used for the ModelSerializer.
    :type model: Model
    :param field_names: The model field names that must be added to the serializer.
        They are appended to the `Meta.fields` of the base class, if it has a `Meta`.
    :type field_names: list
    :param field_overrides: A dict containing field overrides where the key is the name
        and the value must be a serializer Field.
    :type field_overrides: dict
    :param base_class: The class that must be extended. Falls back to the
        ModelSerializer if not provided.
    :type base_class: ModelSerializer
    :param meta_ref_name: Optionally a custom ref name can be set. If not provided,
        then the class name of the model and base class are used.
    :type meta_ref_name: str
    :param required_fields: List of field names that should be present even when
        performing partial validation.
    :type required_fields: list[str]
    :param base_mixins: An optional list of mixins that must be added to the
        serializer. They are placed before the base class in the bases of the
        generated class.
    :type base_mixins: list[serializers.Serializer]
    :return: The generated model serializer containing the provided fields.
    :rtype: ModelSerializer
    """

    # A name that is assigned in the body of the `Meta` class below can't be read
    # from this function's scope under that same name, hence the alias.
    serializer_model = model

    if not meta_ref_name:
        meta_ref_name = serializer_model.__name__

        if base_class:
            meta_ref_name += base_class.__name__

    parent_class = base_class or ModelSerializer

    if hasattr(parent_class, "Meta"):
        parent_meta = parent_class.Meta
        all_field_names = list(parent_meta.fields) + list(field_names)
    else:
        parent_meta = object
        all_field_names = list(field_names)

    class Meta(parent_meta):
        ref_name = meta_ref_name
        model = serializer_model
        fields = all_field_names

    def validate(self, value):
        # Fails on the first required field that is missing from the value.
        for required_name in required_fields or ():
            if required_name not in value:
                raise serializers.ValidationError(
                    {f"{required_name}": "This field is required."}
                )

        return value

    attrs = {"Meta": Meta}
    attrs.update(field_overrides or {})
    attrs["validate"] = validate

    bases = (*(base_mixins or []), parent_class)
    return type(f"{serializer_model.__name__}Serializer", bases, attrs)
|
|
|
|
|
|
class MappingSerializer:
    """
    A placeholder class for the `MappingSerializerExtension` extension class. It only
    keeps the provided arguments as attributes.
    """

    def __init__(self, component_name, mapping, name, many=False):
        """
        :param component_name: Stored as `component_name` on the instance.
        :param mapping: Stored as `mapping` on the instance.
        :param name: Stored as `name` on the instance.
        :param many: Stored as `many` on the instance.
        """

        # `read_only` and `partial` can't be configured and are always False.
        self.read_only = False
        self.component_name, self.mapping, self.name = component_name, mapping, name
        self.many = many
        self.partial = False
|
|
|
|
|
|
class CustomFieldRegistryMappingSerializer:
    """
    A placeholder class for the `CustomFieldRegistryMappingSerializerExtension`
    extension class. It only keeps the provided arguments as attributes.
    """

    def __init__(self, registry, base_class, many=False):
        """
        :param registry: Stored as `registry` on the instance.
        :param base_class: Stored as `base_class` on the instance.
        :param many: Stored as `many` on the instance.
        """

        # `read_only` and `partial` can't be configured and are always False.
        self.read_only = False
        self.registry, self.base_class = registry, base_class
        self.many = many
        self.partial = False
|
|
|
|
|
|
class DiscriminatorCustomFieldsMappingSerializer:
    """
    A placeholder class for the `DiscriminatorCustomFieldsMappingSerializerExtension`
    extension class. It only keeps the provided arguments as attributes.
    """

    def __init__(
        self,
        registry=None,
        base_class=None,
        type_field_name="type",
        many=False,
        help_text=None,
        request=False,
        context=None,
    ):
        """
        :param registry: Stored as `registry` on the instance.
        :param base_class: Stored as `base_class` on the instance.
        :param type_field_name: Stored as `type_field_name` on the instance.
        :param many: Stored as `many` on the instance.
        :param help_text: Stored as `help_text` on the instance.
        :param request: Stored as `request` on the instance.
        :param context: Stored as `context` on the instance. A new empty dict is
            used if not provided.
        """

        # `read_only` and `partial` can't be configured and are always False.
        self.read_only = False
        self.registry, self.base_class = registry, base_class
        self.type_field_name = type_field_name
        self.many = many
        self.help_text = help_text
        self.partial = False
        self.request = request

        # Every instance gets its own dict when no context is provided.
        if context is None:
            context = {}
        self.context = context

    # Borrowing the method of the list serializer makes spectacular believe that this
    # is not a customized list serializer. Otherwise it would try to use its own
    # customized list serializer extension code, which doesn't work with our custom
    # extension.
    to_representation = serializers.ListSerializer.to_representation
|
|
|
|
|
|
class DiscriminatorMappingSerializer:
    """
    A placeholder class for the `DiscriminatorMappingSerializerExtension` extension
    class. It only keeps the provided arguments as attributes.
    """

    def __init__(
        self,
        component_name=None,
        mapping=None,
        type_field_name="type",
        many=False,
        context=None,
    ):
        """
        :param component_name: Stored as `component_name` on the instance.
        :param mapping: Stored as `mapping` on the instance.
        :param type_field_name: Stored as `type_field_name` on the instance.
        :param many: Stored as `many` on the instance.
        :param context: Stored as `context` on the instance. A new empty dict is
            used if not provided.
        """

        # `read_only` and `partial` can't be configured and are always False.
        self.read_only = False
        self.component_name, self.mapping = component_name, mapping
        self.type_field_name = type_field_name
        self.many = many
        self.partial = False

        # Every instance gets its own dict when no context is provided.
        if context is None:
            context = {}
        self.context = context

    # Borrowing the method of the list serializer makes spectacular believe that this
    # is not a customized list serializer. Otherwise it would try to use its own
    # customized list serializer extension code, which doesn't work with our custom
    # extension.
    to_representation = serializers.ListSerializer.to_representation
|