# Source code for tim_common.marshmallow_dataclass

"""
This library allows the conversion of python 3.7's :mod:`dataclasses`
to :mod:`marshmallow` schemas.

It takes a python class, and generates a marshmallow schema for it.

Simple example::

    from marshmallow import Schema
    from marshmallow_dataclass import dataclass

    @dataclass
    class Point:
      x:float
      y:float

    point = Point(x=0, y=0)
    point_json = Point.Schema().dumps(point)

Full example::

    import datetime
    from dataclasses import field
    from typing import ClassVar, Type
    import marshmallow
    from marshmallow import Schema
    from marshmallow_dataclass import dataclass

    @dataclass
    class User:
      birth: datetime.date = field(metadata= {
        "required": True # A parameter to pass to marshmallow's field
      })
      website:str = field(metadata = {
        "marshmallow_field": marshmallow.fields.Url() # Custom marshmallow field
      })
      Schema: ClassVar[Type[Schema]] = Schema # For the type checker
"""
import dataclasses
import inspect
from enum import EnumMeta
from functools import lru_cache
from typing import (
    overload,
    Type,
    List,
    cast,
    Tuple,
    Optional,
    Any,
    Mapping,
    TypeVar,
    Callable,
)

import marshmallow
import marshmallow.validate
import typing_inspect

import tim_common.typing_inspect_ext as typing_inspect_ext

# Public API of this module. Only names actually defined here may be listed;
# an undefined entry makes ``from tim_common.marshmallow_dataclass import *``
# fail with AttributeError.
__all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema"]

import webargs.fields

from marshmallow.fields import Integer

NoneType = type(None)
_U = TypeVar("_U")

# Whitelist of dataclass members that are copied to the generated schema
# (in addition to members carrying a marshmallow hook, e.g. @validates).
MEMBERS_WHITELIST: set[str] = {"Meta"}

# Maximum number of generated schema classes kept in class_schema's LRU cache.
# The cache also means repeated calls for the same class (and base schema)
# return the same schema class instead of generating a duplicate.
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024


# _cls should never be specified by keyword, so start it with an
# underscore.  The presence of _cls is used to detect if this
# decorator is being called with parameters or not.
# _cls should never be specified by keyword, so start it with an
# underscore.  The presence of _cls is used to detect if this
# decorator is being called with parameters or not.
def dataclass(
    _cls: type[_U] | None = None,
    *,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    base_schema: type[marshmallow.Schema] | None = None,
) -> type[_U] | Callable[[type[_U]], type[_U]]:
    """
    This decorator does the same as dataclasses.dataclass, but also applies :func:`add_schema`.
    It adds a `.Schema` attribute to the class object.

    :param _cls: the class being decorated; ``None`` when the decorator is used with
        arguments, e.g. ``@dataclass(frozen=True)``
    :param repr: forwarded to :func:`dataclasses.dataclass`
    :param eq: forwarded to :func:`dataclasses.dataclass`
    :param order: forwarded to :func:`dataclasses.dataclass` (generates ``__lt__`` etc.)
    :param unsafe_hash: forwarded to :func:`dataclasses.dataclass`
    :param frozen: forwarded to :func:`dataclasses.dataclass`
    :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
    :return: the decorated class, or a decorator when ``_cls`` is ``None``

    >>> @dataclass
    ... class Artist:
    ...     name: str
    >>> Artist.Schema
    <class 'marshmallow.schema.Artist'>

    >>> from typing import ClassVar
    >>> from marshmallow import Schema
    >>> @dataclass(order=True)  # also generates the ordering methods (__lt__, __le__, ...)
    ... class Point:
    ...     x: float
    ...     y: float
    ...     Schema: ClassVar[Type[Schema]] = Schema  # For the type checker
    ...
    >>> Point.Schema().load({'x':0, 'y':0})  # This line can be statically type checked
    Point(x=0.0, y=0.0)
    """
    # dataclass's typing doesn't expect it to be called as a function, so ignore type check
    dc = dataclasses.dataclass(  # type: ignore
        _cls, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
    )
    if _cls is None:
        # Used as ``@dataclass(...)``: ``dc`` is the stdlib decorator still waiting for the class.
        return lambda cls: add_schema(dc(cls), base_schema)
    return add_schema(dc, base_schema)
@overload
def add_schema(_cls: type[_U]) -> type[_U]:
    ...


@overload
def add_schema(
    *, base_schema: type[marshmallow.Schema] | None = None
) -> Callable[[type[_U]], type[_U]]:
    # base_schema is keyword-only here: a positional first argument is bound to
    # ``_cls`` at runtime and would be decorated instead of used as the base schema.
    ...


@overload
def add_schema(
    _cls: type[_U], base_schema: type[marshmallow.Schema] | None = None
) -> type[_U]:
    ...


def add_schema(_cls=None, base_schema=None):
    """
    This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass.
    It uses :func:`class_schema` internally.

    :param _cls: The dataclass to which a Schema should be added; ``None`` when the
        decorator is used with arguments, e.g. ``@add_schema(base_schema=...)``
    :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
    :return: ``_cls`` with a ``Schema`` attribute, or a decorator when ``_cls`` is ``None``

    >>> class BaseSchema(marshmallow.Schema):
    ...     def on_bind_field(self, field_name, field_obj):
    ...         field_obj.data_key = (field_obj.data_key or field_name).upper()
    >>> @add_schema(base_schema=BaseSchema)
    ... @dataclasses.dataclass
    ... class Artist:
    ...     names: Tuple[str, str]
    >>> artist = Artist.Schema().loads('{"NAMES": ["Martin", "Ramirez"]}')
    >>> artist
    Artist(names=('Martin', 'Ramirez'))
    """

    def decorator(clazz: type[_U]) -> type[_U]:
        clazz.Schema = class_schema(clazz, base_schema)  # type: ignore
        return clazz

    # Compare with None rather than relying on truthiness: a class whose metaclass
    # defines __bool__/__len__ could otherwise be mistaken for "no class given".
    return decorator(_cls) if _cls is not None else decorator
def class_schema(
    clazz: type, base_schema: type[marshmallow.Schema] | None = None
) -> type[marshmallow.Schema]:
    """
    Convert a class to a marshmallow schema

    :param clazz: A python class (may be a dataclass)
    :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
    :return: A marshmallow Schema corresponding to the dataclass
    :raises TypeError: if ``clazz`` is not a dataclass and cannot be turned into one

    Generated schema classes are cached (up to ``MAX_CLASS_SCHEMA_CACHE_SIZE`` entries), so
    repeated calls with the same arguments return the same schema class.

    .. note::
        All the arguments supported by marshmallow field classes can
        be passed in the `metadata` dictionary of a field.


    If you want to use a custom marshmallow field
    (one that has no equivalent python type), you can pass it as the
    ``marshmallow_field`` key in the metadata dictionary.

    >>> import typing
    >>> Meters = typing.NewType('Meters', float)
    >>> @dataclasses.dataclass()
    ... class Building:
    ...     height: Optional[Meters]
    ...     name: str = dataclasses.field(default="anonymous")
    ...     class Meta:
    ...         ordered = True
    ...
    >>> class_schema(Building) # Returns a marshmallow schema class (not an instance)
    <class 'marshmallow.schema.Building'>

    >>> @dataclasses.dataclass()
    ... class City:
    ...     name: str = dataclasses.field(metadata={'required':True})
    ...     best_building: Building # Reference to another dataclass. A schema will be created for it too.
    ...     other_buildings: List[Building] = dataclasses.field(default_factory=lambda: [])
    ...
    >>> citySchema = class_schema(City)()
    >>> # An Optional field without a default is still required, but it accepts None.
    >>> city = citySchema.load({"name":"Paris", "best_building": {"name": "Eiffel Tower", "height": None}})
    >>> city
    City(name='Paris', best_building=Building(height=None, name='Eiffel Tower'), other_buildings=[])

    >>> citySchema.load({"name":"Paris"})
    Traceback (most recent call last):
        ...
    marshmallow.exceptions.ValidationError: {'best_building': ['Missing data for required field.']}

    >>> city_json = citySchema.dump(city)
    >>> city_json['best_building'] # We get an OrderedDict because we specified ordered = True in the Meta class
    OrderedDict([('height', None), ('name', 'Eiffel Tower')])

    >>> @dataclasses.dataclass()
    ... class Person:
    ...     name: str = dataclasses.field(default="Anonymous")
    ...     friends: List['Person'] = dataclasses.field(default_factory=lambda:[]) # Recursive field
    ...
    >>> person = class_schema(Person)().load({
    ...     "friends": [{"name": "Roger Boucher"}]
    ... })
    >>> person
    Person(name='Anonymous', friends=[Person(name='Roger Boucher', friends=[])])

    >>> @dataclasses.dataclass()
    ... class C:
    ...     important: int = dataclasses.field(init=True, default=0)
    ...     # Only fields that are in the __init__ method will be added:
    ...     unimportant: int = dataclasses.field(init=False, default=0)
    ...
    >>> c = class_schema(C)().load({
    ...     "important": 9, # This field will be imported
    ...     "unimportant": 9 # This field will NOT be imported
    ... }, unknown=marshmallow.EXCLUDE)
    >>> c
    C(important=9, unimportant=0)

    >>> @dataclasses.dataclass
    ... class Website:
    ...     url:str = dataclasses.field(metadata = {
    ...         "marshmallow_field": marshmallow.fields.Url() # Custom marshmallow field
    ...     })
    ...
    >>> class_schema(Website)().load({"url": "I am not a good URL !"})
    Traceback (most recent call last):
        ...
    marshmallow.exceptions.ValidationError: {'url': ['Not a valid URL.']}

    >>> @dataclasses.dataclass
    ... class NeverValid:
    ...     @marshmallow.validates_schema
    ...     def validate(self, data, **_):
    ...         raise marshmallow.ValidationError('never valid')
    ...
    >>> class_schema(NeverValid)().load({})
    Traceback (most recent call last):
        ...
    marshmallow.exceptions.ValidationError: {'_schema': ['never valid']}

    >>> # noinspection PyTypeChecker
    >>> class_schema(None)  # unsupported type
    Traceback (most recent call last):
      ...
    TypeError: None is not a dataclass and cannot be turned into one.

    >>> @dataclasses.dataclass
    ... class Anything:
    ...     name: str
    ...     @marshmallow.validates('name')
    ...     def validates(self, value):
    ...         if len(value) > 5: raise marshmallow.ValidationError("Name too long")
    >>> class_schema(Anything)().load({"name": "aaaaaargh"})
    Traceback (most recent call last):
    ...
    marshmallow.exceptions.ValidationError: {'name': ['Name too long']}
    """
    return _proxied_class_schema(clazz, base_schema)
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
def _proxied_class_schema(
    clazz: type, base_schema: type[marshmallow.Schema] | None = None
) -> type[marshmallow.Schema]:
    """
    Build the marshmallow schema class for ``clazz`` (memoized with :func:`lru_cache`).

    A plain class that is not yet a dataclass is first converted with
    :func:`dataclasses.dataclass` and its schema is then built through :func:`class_schema`.

    :raises TypeError: if ``clazz`` is not a class, or is not a dataclass and
        :func:`dataclasses.dataclass` cannot convert it.
    """
    if not isinstance(clazz, type):
        # e.g. None: ``dataclasses.dataclass(None)`` would not fail but return a decorator.
        raise TypeError(
            f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
        )
    if not dataclasses.is_dataclass(clazz):
        # Guard only the conversion itself: errors raised while building the schema of the
        # converted class (e.g. an invalid field default) keep their own, more useful message.
        try:
            converted = dataclasses.dataclass(clazz)
        except Exception as exc:
            raise TypeError(
                f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
            ) from exc
        return class_schema(converted, base_schema)

    # noinspection PyDataclass
    fields: tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)

    # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
    attributes = {
        k: v
        for k, v in inspect.getmembers(clazz)
        if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
    }
    # Update the schema members to contain marshmallow fields instead of dataclass fields
    attributes.update(
        (
            field.name,
            field_for_schema(
                field.type,
                _get_field_default(field),
                _field_metadata(field),
                clazz,
                field.name,
                base_schema,
            ),
        )
        for field in fields
        if field.init
    )
    schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
    return cast(type[marshmallow.Schema], schema_class)


def _field_metadata(field: dataclasses.Field) -> Mapping[str, Any]:
    """
    Return the metadata to pass to :func:`field_for_schema` for a dataclass field.

    ``_get_field_default`` calls a field's ``default_factory`` only once, when the schema is
    built. Used as the load-time default, that single object (e.g. one list) would be
    shared by every instance loaded without the field. Marshmallow calls a callable
    ``missing`` value on every load, so the factory itself is passed as ``missing`` instead,
    unless the user already chose a load default or marked the field required ('missing'
    must not be set for required fields).
    """
    metadata = field.metadata
    default_factory = field.default_factory  # type: ignore
    if (
        default_factory is not dataclasses.MISSING
        and not metadata.get("required")
        and "missing" not in metadata
        and "load_default" not in metadata
    ):
        return {**metadata, "missing": default_factory}
    return metadata


def _field_by_type(
    typ: type | Any, base_schema: type[marshmallow.Schema] | None
) -> Callable[..., marshmallow.fields.Field] | None:
    """
    Return the marshmallow field class (or field factory) registered for ``typ``, or None.

    ``Any`` maps to a nullable :class:`marshmallow.fields.Raw`. Other types are looked up in
    ``base_schema.TYPE_MAPPING`` first, then in ``marshmallow.Schema.TYPE_MAPPING``.
    """
    if typ is Any:
        return lambda **x: marshmallow.fields.Raw(**{**x, "allow_none": True})
    return (
        base_schema and base_schema.TYPE_MAPPING.get(typ)
    ) or marshmallow.Schema.TYPE_MAPPING.get(typ)


class SemiStrictIntegerField(marshmallow.fields.Field):
    """A "semi-strict" integer field that accepts integers and strings convertible to integers but not floats."""

    def _serialize(self, value: Any, attr: str, obj: Any, **kwargs):
        # Serialize like marshmallow.fields.Integer so that schemas containing int fields
        # can be dumped: None stays None, everything else is emitted as an int.
        if value is None:
            return None
        return int(value)

    def _deserialize(
        self,
        value: Any,
        attr: str | None,
        data: Mapping[str, Any] | None,
        **kwargs,
    ):
        if isinstance(value, int):
            return value
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                pass
        # Floats and every other type are rejected instead of being silently rounded.
        raise self.make_error("validator_failed")
def field_for_schema(
    typ: type,
    default=dataclasses.MISSING,
    metadata: Mapping[str, Any] | None = None,
    clazz: type | None = None,
    name: str | None = None,
    base_schema: type[marshmallow.Schema] | None = None,
) -> marshmallow.fields.Field:
    """
    Get a marshmallow Field corresponding to the given python type.
    The metadata of the dataclass field is used as arguments to the marshmallow Field.

    :param typ: The type for which a field should be generated
    :param default: value to use for (de)serialization when the field is missing;
        ``dataclasses.MISSING`` (the default) makes the field required
    :param metadata: Additional parameters to pass to the marshmallow field constructor
    :param clazz: the class declaring the field; used in error messages and to resolve
        ``TypeVar`` annotations of generic classes
    :param name: the field name; used in error messages
    :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
    :raises TypeError: if ``default`` is not an instance of the annotated type

    >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True))
    >>> type(int_field).__name__
    'SemiStrictIntegerField'
    >>> int_field.default
    9
    >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
    <class 'marshmallow.fields.Url'>
    """
    metadata = {} if metadata is None else dict(metadata)
    if default is not marshmallow.missing and default is not dataclasses.MISSING:
        metadata.setdefault("default", default)
        # 'missing' must not be set for required fields.
        if not metadata.get("required"):
            metadata.setdefault("missing", default)
    elif default is dataclasses.MISSING:
        metadata.setdefault("required", True)

    # If the field was already defined by the user
    predefined_field = metadata.get("marshmallow_field")
    if predefined_field:
        return predefined_field

    # Generic types specified without type arguments
    if typ is list:
        typ = list[Any]
    elif typ is dict:
        typ = dict[Any, Any]

    # Base types
    field = _field_by_type(typ, base_schema)
    if field:
        if field is Integer:
            # We want a custom integer field because otherwise floats would get silently rounded to integers.
            # The built-in Integer field with strict=True option is too strict because it doesn't allow strings
            # that are convertible to ints.
            field = SemiStrictIntegerField
        check_default(clazz, default, typ, name)
        return field(**metadata)

    if typing_inspect.is_literal_type(typ):
        arguments = typing_inspect.get_args(typ)
        return marshmallow.fields.Raw(
            validate=(
                marshmallow.validate.Equal(arguments[0])
                if len(arguments) == 1
                else marshmallow.validate.OneOf(arguments)
            ),
            **metadata,
        )

    # Generic types
    origin = typing_inspect_ext.get_origin(typ)
    if origin:
        arguments = typing_inspect_ext.get_args(typ, True)
        if origin is list:
            check_default(clazz, default, list, name)
            child_type = field_for_schema(
                arguments[0], clazz=clazz, base_schema=base_schema
            )
            if metadata.get("list_type") == "delimited":
                # Forward required/missing/validate/... like the plain List branch does;
                # only the marker key itself is not a field argument.
                list_metadata = {k: v for k, v in metadata.items() if k != "list_type"}
                return webargs.fields.DelimitedList(child_type, **list_metadata)
            return marshmallow.fields.List(child_type, **metadata)
        if origin is tuple:
            check_default(clazz, default, tuple, name)
            children = tuple(
                field_for_schema(arg, clazz=clazz, base_schema=base_schema)
                for arg in arguments
            )
            return marshmallow.fields.Tuple(children, **metadata)
        if origin is dict:
            check_default(clazz, default, dict, name)
            return marshmallow.fields.Dict(
                keys=field_for_schema(
                    arguments[0], clazz=clazz, base_schema=base_schema
                ),
                values=field_for_schema(
                    arguments[1], clazz=clazz, base_schema=base_schema
                ),
                **metadata,
            )
        if typing_inspect_ext.is_union_type(typ):
            has_none = typing_inspect_ext.is_optional_type(typ) or any(
                subtyp is Any for subtyp in arguments
            )
            if has_none:
                metadata["allow_none"] = True
                if default is not dataclasses.MISSING:
                    metadata["required"] = False
            subfields = [
                field_for_schema(
                    subtyp, metadata=metadata, clazz=clazz, base_schema=base_schema
                )
                for subtyp in arguments
                if subtyp is not NoneType
            ]
            import marshmallow_union

            # The default only has to match one member of the union.
            if default is not dataclasses.MISSING:
                if not any(isinstance_noexcept(default, t) for t in arguments):
                    report_default_error(clazz, default, typ, name)
            return marshmallow_union.Union(subfields, **metadata)

    # typing.NewType has a __supertype__ attribute. isinstance() cannot be used with a
    # NewType, so the default is validated by the recursive call against the supertype.
    newtype_supertype = getattr(typ, "__supertype__", None)
    if newtype_supertype:
        # Add the information coming our custom NewType implementation
        metadata = {
            "description": typ.__name__,
            **getattr(typ, "_marshmallow_args", {}),
            **metadata,
        }
        field = getattr(typ, "_marshmallow_field", None)
        if field:
            return field(**metadata)
        return field_for_schema(
            newtype_supertype,
            metadata=metadata,
            default=default,
            clazz=clazz,
            base_schema=base_schema,
        )

    # generic types: resolve the TypeVar against the class's first generic base; the
    # recursive call validates the default against the instantiated type.
    if type(typ) is TypeVar:
        b = typing_inspect.get_generic_bases(clazz)[0]
        type_index = typing_inspect_ext.get_args(
            typing_inspect.get_generic_bases(typing_inspect_ext.get_origin(b))[0]
        ).index(typ)
        instantiated_type = typing_inspect_ext.get_args(b)[type_index]
        return field_for_schema(
            instantiated_type,
            metadata=metadata,
            default=default,
            clazz=clazz,
            base_schema=base_schema,
        )

    check_default(clazz, default, typ, name)

    # enumerations
    if isinstance(typ, EnumMeta):
        import marshmallow_enum

        return marshmallow_enum.EnumField(typ, **metadata)

    # Nested marshmallow dataclass
    nested_schema = getattr(typ, "Schema", None)

    # Nested dataclasses
    forward_reference = getattr(typ, "__forward_arg__", None)

    nested = (
        nested_schema
        or forward_reference
        or class_schema(typ, base_schema=base_schema)
    )

    return marshmallow.fields.Nested(nested, **metadata)
def isinstance_noexcept(default: Any, t: type) -> bool:
    """
    ``isinstance(default, t)``, but return False instead of raising TypeError when ``t``
    cannot be used with :func:`isinstance` (e.g. a subscripted generic such as ``list[int]``).
    """
    try:
        return isinstance(default, t)
    except TypeError:
        return False


def check_default(clazz: type | None, default: Any, typ: type, name: str | None) -> None:
    """Raise TypeError via :func:`report_default_error` if a given default is not a ``typ``."""
    if default is not dataclasses.MISSING and not isinstance(default, typ):
        report_default_error(clazz, default, typ, name)


def report_default_error(clazz: type | None, default: Any, typ: type, name: str | None):
    """
    Raise a TypeError describing an invalid default value.

    ``clazz`` may be None (the default of ``field_for_schema``); the message then names the
    class as ``None`` instead of failing with an AttributeError on ``clazz.__name__``.
    """
    class_name = getattr(clazz, "__name__", repr(clazz))
    if not name:
        raise TypeError(
            f"Invalid default value {default} supplied in class {class_name} for {typ}"
        )
    raise TypeError(
        f"Invalid default value {default} supplied for field {name} in class {class_name} for {typ}"
    )


def _base_schema(
    clazz: type, base_schema: type[marshmallow.Schema] | None = None
) -> type[marshmallow.Schema]:
    """
    Create the base class of the schema generated for `clazz`: a subclass of `base_schema`
    (or of ``marshmallow.Schema`` when none is given) whose ``post_load`` hook builds a
    `clazz` instance from the loaded data.
    """

    # Remove `type: ignore` when mypy handles dynamic base classes
    # https://github.com/python/mypy/issues/2813
    class BaseSchema(base_schema or marshmallow.Schema):  # type: ignore
        @marshmallow.post_load
        def make_data_class(self, data, **_):
            return clazz(**data)

    return BaseSchema


def _get_field_default(field: dataclasses.Field):
    """
    Return the default value of a dataclass field.

    A ``default_factory`` is called (once) to produce the value; a field without any
    default yields ``dataclasses.MISSING``.

    >>> _get_field_default(dataclasses.field()) is dataclasses.MISSING
    True
    >>> _get_field_default(dataclasses.field(default_factory=list))
    []
    """
    # Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
    default_factory = field.default_factory  # type: ignore
    if default_factory is not dataclasses.MISSING:
        return default_factory()
    return field.default


if __name__ == "__main__":
    import doctest

    doctest.testmod(verbose=True)