docs for muutils v0.6.17

muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way

d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.
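The failure described above is easy to reproduce with only the standard library. The sketch below uses a hypothetical `Plain` class (not part of muutils) with a `set` field, which has no JSON representation:

```python
import dataclasses
import json


@dataclasses.dataclass
class Plain:  # hypothetical example class, not part of muutils
    a: int
    tags: set  # a set has no JSON representation


# asdict() succeeds: it only converts the dataclass into a plain dict
d = dataclasses.asdict(Plain(a=1, tags={"x"}))

# json.dumps() is where the failure surfaces
try:
    json.dumps(d)
    failed = False
    message = ""
except TypeError as err:
    failed = True
    message = str(err)

print(failed, message)
```

The conversion to a dict works, and the `TypeError` only appears at `json.dumps` time, which is the gap this module is meant to close.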

Instead, you define your class:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: x.__class__.__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
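To make the round trip above concrete without depending on muutils or torch, here is a minimal stdlib-only sketch of the same idea: per-field hooks for values that JSON cannot represent, plus a `__format__` tag that records which class produced the dict. This is not how muutils is implemented. The helpers `to_jsonable` and `from_jsonable`, the `"ser"` and `"de"` metadata keys, and the `Inner` and `Outer` classes are all hypothetical names chosen for illustration.

```python
import dataclasses
import json
from fractions import Fraction


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    # Fraction is not JSON-serializable, so attach (hypothetical) hooks via metadata
    ratio: Fraction = dataclasses.field(
        default=Fraction(1, 2),
        metadata={"ser": str, "de": Fraction},
    )


def to_jsonable(obj) -> dict:
    """Convert a dataclass instance to a JSON-compatible dict tagged with its class name."""
    out = {"__format__": type(obj).__name__}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if dataclasses.is_dataclass(value):
            value = to_jsonable(value)  # recurse into nested dataclasses
        elif "ser" in f.metadata:
            value = f.metadata["ser"](value)  # field-level serialization hook
        out[f.name] = value
    return out


def from_jsonable(cls, data: dict):
    """Rebuild an instance of `cls` from the output of `to_jsonable`."""
    assert data["__format__"] == cls.__name__
    kwargs = {}
    for f in dataclasses.fields(cls):
        value = data[f.name]
        # f.type is a real class here because annotations are not stringified
        if dataclasses.is_dataclass(f.type):
            value = from_jsonable(f.type, value)
        elif "de" in f.metadata:
            value = f.metadata["de"](value)  # field-level loading hook
        kwargs[f.name] = value
    return cls(**kwargs)


obj = Outer(x="q", y=Inner(a=1, b="q"), ratio=Fraction(3, 4))
s = json.dumps(to_jsonable(obj))
restored = from_jsonable(Outer, json.loads(s))
print(s)
print(restored == obj)
```

The sketch skips everything that makes the real decorator useful in practice (type validation, properties, ZANJ registration), but it shows why a format tag and paired serialize/load functions are enough to make the output recoverable.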

  1"""save and load objects to and from json or compatible formats in a recoverable way
  2
  3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
  4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
  5
  6Instead, you define your class:
  7
  8```python
  9@serializable_dataclass
 10class MyClass(SerializableDataclass):
 11    a: int
 12    b: str
 13```
 14
 15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
 16
 17    >>> my_obj = MyClass(a=1, b="q")
 18    >>> s = json.dumps(my_obj.serialize())
 19    >>> s
 20    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
 21    >>> read_obj = MyClass.load(json.loads(s))
 22    >>> read_obj == my_obj
 23    True
 24
 25This isn't too impressive on its own, but it gets more useful when you have nested classes,
 26or fields that are not json-serializable by default:
 27
 28```python
 29@serializable_dataclass
 30class NestedClass(SerializableDataclass):
 31    x: str
 32    y: MyClass
 33    act_fun: torch.nn.Module = serializable_field(
 34        default=torch.nn.ReLU(),
 35        serialization_fn=lambda x: x.__class__.__name__,
 36        deserialize_fn=lambda x: getattr(torch.nn, x)(),
 37    )
 38```
 39
 40which gives us:
 41
 42    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
 43    >>> s = json.dumps(nc.serialize())
 44    >>> s
 45    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
 46    >>> read_nc = NestedClass.load(json.loads(s))
 47    >>> read_nc == nc
 48    True
 49
 50"""
 51
 52from __future__ import annotations
 53
 54import abc
 55import dataclasses
 56import functools
 57import json
 58import sys
 59import typing
 60import warnings
 61from typing import Any, Optional, Type, TypeVar
 62
 63from muutils.errormode import ErrorMode
 64from muutils.validate_type import validate_type
 65from muutils.json_serialize.serializable_field import (
 66    SerializableField,
 67    serializable_field,
 68)
 69from muutils.json_serialize.util import array_safe_eq, dc_eq
 70
 71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
 72
 73
 74def _dataclass_transform_mock(
 75    *,
 76    eq_default: bool = True,
 77    order_default: bool = False,
 78    kw_only_default: bool = False,
 79    frozen_default: bool = False,
 80    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
 81    **kwargs: Any,
 82) -> typing.Callable:
 83    "mock `typing.dataclass_transform` for python <3.11"
 84
 85    def decorator(cls_or_fn):
 86        cls_or_fn.__dataclass_transform__ = {
 87            "eq_default": eq_default,
 88            "order_default": order_default,
 89            "kw_only_default": kw_only_default,
 90            "frozen_default": frozen_default,
 91            "field_specifiers": field_specifiers,
 92            "kwargs": kwargs,
 93        }
 94        return cls_or_fn
 95
 96    return decorator
 97
 98
 99dataclass_transform: typing.Callable
100if sys.version_info < (3, 11):
101    dataclass_transform = _dataclass_transform_mock
102else:
103    dataclass_transform = typing.dataclass_transform
104
105
106T = TypeVar("T")
107
108
109class CantGetTypeHintsWarning(UserWarning):
110    "special warning for when we can't get type hints"
111
112    pass
113
114
115class ZanjMissingWarning(UserWarning):
116    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `zanj_register_loader_serializable_dataclass` will not work"
117
118    pass
119
120
121_zanj_loading_needs_import: bool = True
122"flag to keep track of if we have successfully imported ZANJ"
123
124
125def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
126    """Register a serializable dataclass with the ZANJ import
127
128    this allows `ZANJ().read()` to load the class and not just return plain dicts
129
130
131    # TODO: there is some duplication here with register_loader_handler
132    """
133    global _zanj_loading_needs_import
134
135    if _zanj_loading_needs_import:
136        try:
137            from zanj.loading import (  # type: ignore[import]
138                LoaderHandler,
139                register_loader_handler,
140            )
141        except ImportError:
142            warnings.warn(
143                "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
144                ZanjMissingWarning,
145            )
146            return
147
148    _format: str = f"{cls.__name__}(SerializableDataclass)"
149    lh: LoaderHandler = LoaderHandler(
150        check=lambda json_item, path=None, z=None: (  # type: ignore
151            isinstance(json_item, dict)
152            and "__format__" in json_item
153            and json_item["__format__"].startswith(_format)
154        ),
155        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
156        uid=_format,
157        source_pckg=cls.__module__,
158        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
159    )
160
161    register_loader_handler(lh)
162
163    return lh
164
165
166_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
167_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
168
169
170class FieldIsNotInitOrSerializeWarning(UserWarning):
171    pass
172
173
174def SerializableDataclass__validate_field_type(
175    self: SerializableDataclass,
176    field: SerializableField | str,
177    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
178) -> bool:
179    """given a dataclass, check the field matches the type hint
180
181    this function is written to `SerializableDataclass.validate_field_type`
182
183    # Parameters:
184     - `self : SerializableDataclass`
185       `SerializableDataclass` instance
186     - `field : SerializableField | str`
187        field to validate, will get from `self.__dataclass_fields__` if an `str`
188     - `on_typecheck_error : ErrorMode`
189        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
190       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
191
192    # Returns:
193     - `bool`
194        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
195    """
196    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
197
198    # get field
199    _field: SerializableField
200    if isinstance(field, str):
201        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
202    else:
203        _field = field
204
205    # do nothing case
206    if not _field.assert_type:
207        return True
208
209    # if field is not `init` or not `serialize`, skip but warn
210    # TODO: how to handle fields which are not `init` or `serialize`?
211    if not _field.init or not _field.serialize:
212        warnings.warn(
213            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
214            FieldIsNotInitOrSerializeWarning,
215        )
216        return True
217
218    assert isinstance(
219        _field, SerializableField
220    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
221
222    # get field type hints
223    try:
224        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
225    except KeyError as e:
226        on_typecheck_error.process(
227            (
228                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
229                + f"{get_cls_type_hints(self.__class__) = }\n"
230                + f"Python version is {sys.version_info = }. You can:\n"
231                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
232                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
233                + "  - use python 3.9.x or higher\n"
234                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
235            ),
236            except_cls=TypeError,
237            except_from=e,
238        )
239        return False
240
241    # get the value
242    value: Any = getattr(self, _field.name)
243
244    # validate the type
245    try:
246        type_is_valid: bool
247        # validate the type with the default type validator
248        if _field.custom_typecheck_fn is None:
249            type_is_valid = validate_type(value, field_type_hint)
250        # validate the type with a custom type validator
251        else:
252            type_is_valid = _field.custom_typecheck_fn(field_type_hint)
253
254        return type_is_valid
255
256    except Exception as e:
257        on_typecheck_error.process(
258            "exception while validating type: "
259            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
260            except_cls=ValueError,
261            except_from=e,
262        )
263        return False
264
265
266def SerializableDataclass__validate_fields_types__dict(
267    self: SerializableDataclass,
268    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
269) -> dict[str, bool]:
270    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
271
272    returns a dict of field names to bools, where the bool is if the field type is valid
273    """
274    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
275
276    # if except, bundle the exceptions
277    results: dict[str, bool] = dict()
278    exceptions: dict[str, Exception] = dict()
279
280    # for each field in the class
281    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
282    for field in cls_fields:
283        try:
284            results[field.name] = self.validate_field_type(field, on_typecheck_error)
285        except Exception as e:
286            results[field.name] = False
287            exceptions[field.name] = e
288
289    # figure out what to do with the exceptions
290    if len(exceptions) > 0:
291        on_typecheck_error.process(
292            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
293            + "\n\t"
294            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
295            except_cls=ValueError,
296            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
297            except_from=list(exceptions.values())[0],
298        )
299
300    return results
301
302
303def SerializableDataclass__validate_fields_types(
304    self: SerializableDataclass,
305    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
306) -> bool:
307    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
308    return all(
309        SerializableDataclass__validate_fields_types__dict(
310            self, on_typecheck_error=on_typecheck_error
311        ).values()
312    )
313
314
315@dataclass_transform(
316    field_specifiers=(serializable_field, SerializableField),
317)
318class SerializableDataclass(abc.ABC):
319    """Base class for serializable dataclasses
320
321    only for linting and type checking, still need to call `serializable_dataclass` decorator
322
323    # Usage:
324
325    ```python
326    @serializable_dataclass
327    class MyClass(SerializableDataclass):
328        a: int
329        b: str
330    ```
331
332    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
333
334        >>> my_obj = MyClass(a=1, b="q")
335        >>> s = json.dumps(my_obj.serialize())
336        >>> s
337        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
338        >>> read_obj = MyClass.load(json.loads(s))
339        >>> read_obj == my_obj
340        True
341
342    This isn't too impressive on its own, but it gets more useful when you have nested classes,
343    or fields that are not json-serializable by default:
344
345    ```python
346    @serializable_dataclass
347    class NestedClass(SerializableDataclass):
348        x: str
349        y: MyClass
350        act_fun: torch.nn.Module = serializable_field(
351            default=torch.nn.ReLU(),
352            serialization_fn=lambda x: x.__class__.__name__,
353            deserialize_fn=lambda x: getattr(torch.nn, x)(),
354        )
355    ```
356
357    which gives us:
358
359        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
360        >>> s = json.dumps(nc.serialize())
361        >>> s
362        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
363        >>> read_nc = NestedClass.load(json.loads(s))
364        >>> read_nc == nc
365        True
366    """
367
368    def serialize(self) -> dict[str, Any]:
369        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
370        raise NotImplementedError(
371            f"decorate {self.__class__ = } with `@serializable_dataclass`"
372        )
373
374    @classmethod
375    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
376        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
377        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
378
379    def validate_fields_types(
380        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
381    ) -> bool:
382        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
383        return SerializableDataclass__validate_fields_types(
384            self, on_typecheck_error=on_typecheck_error
385        )
386
387    def validate_field_type(
388        self,
389        field: "SerializableField|str",
390        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
391    ) -> bool:
392        """given a dataclass, check the field matches the type hint"""
393        return SerializableDataclass__validate_field_type(
394            self, field, on_typecheck_error=on_typecheck_error
395        )
396
397    def __eq__(self, other: Any) -> bool:
398        return dc_eq(self, other)
399
400    def __hash__(self) -> int:
401        "hashes the json-serialized representation of the class"
402        return hash(json.dumps(self.serialize()))
403
404    def diff(
405        self, other: "SerializableDataclass", of_serialized: bool = False
406    ) -> dict[str, Any]:
407        """get a rich and recursive diff between two instances of a serializable dataclass
408
409        ```python
410        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
411        {'b': {'self': 2, 'other': 3}}
412        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
413        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
414        ```
415
416        # Parameters:
417         - `other : SerializableDataclass`
418           other instance to compare against
419         - `of_serialized : bool`
420           if true, compare serialized data and not raw values
421           (defaults to `False`)
422
423        # Returns:
424         - `dict[str, Any]`
425
426
427        # Raises:
428         - `ValueError` : if the instances are not of the same type
429         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
430        """
431        # match types
432        if type(self) is not type(other):
433            raise ValueError(
434                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
435            )
436
437        # initialize the diff result
438        diff_result: dict = {}
439
440        # if they are the same, return the empty diff
441        if self == other:
442            return diff_result
443
444        # if we are working with serialized data, serialize the instances
445        if of_serialized:
446            ser_self: dict = self.serialize()
447            ser_other: dict = other.serialize()
448
449        # for each field in the class
450        for field in dataclasses.fields(self):  # type: ignore[arg-type]
451            # skip fields that are not for comparison
452            if not field.compare:
453                continue
454
455            # get values
456            field_name: str = field.name
457            self_value = getattr(self, field_name)
458            other_value = getattr(other, field_name)
459
460            # if the values are both serializable dataclasses, recurse
461            if isinstance(self_value, SerializableDataclass) and isinstance(
462                other_value, SerializableDataclass
463            ):
464                nested_diff: dict = self_value.diff(
465                    other_value, of_serialized=of_serialized
466                )
467                if nested_diff:
468                    diff_result[field_name] = nested_diff
469            # only support serializable dataclasses
470            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
471                other_value
472            ):
473                raise ValueError("Non-serializable dataclass is not supported")
474            else:
475                # get the values of either the serialized or the actual values
476                self_value_s = ser_self[field_name] if of_serialized else self_value
477                other_value_s = ser_other[field_name] if of_serialized else other_value
478                # compare the values
479                if not array_safe_eq(self_value_s, other_value_s):
480                    diff_result[field_name] = {"self": self_value, "other": other_value}
481
482        # return the diff result
483        return diff_result
484
485    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
486        """update the instance from a nested dict, useful for configuration from command line args
487
488        # Parameters:
489            - `nested_dict : dict[str, Any]`
490                nested dict to update the instance with
491        """
492        for field in dataclasses.fields(self):  # type: ignore[arg-type]
493            field_name: str = field.name
494            self_value = getattr(self, field_name)
495
496            if field_name in nested_dict:
497                if isinstance(self_value, SerializableDataclass):
498                    self_value.update_from_nested_dict(nested_dict[field_name])
499                else:
500                    setattr(self, field_name, nested_dict[field_name])
501
502    def __copy__(self) -> "SerializableDataclass":
503        "deep copy by serializing and loading the instance to json"
504        return self.__class__.load(json.loads(json.dumps(self.serialize())))
505
506    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
507        "deep copy by serializing and loading the instance to json"
508        return self.__class__.load(json.loads(json.dumps(self.serialize())))
509
510
511# cache this so we don't have to keep getting it
512# TODO: are the types hashable? does this even make sense?
513@functools.lru_cache(typed=True)
514def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
515    "cached typing.get_type_hints for a class"
516    return typing.get_type_hints(cls)
517
518
519def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
520    "helper function to get type hints for a class"
521    cls_type_hints: dict[str, Any]
522    try:
523        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
524        if len(cls_type_hints) == 0:
525            cls_type_hints = typing.get_type_hints(cls)
526
527        if len(cls_type_hints) == 0:
528            raise ValueError(f"empty type hints for {cls.__name__ = }")
529    except (TypeError, NameError, ValueError) as e:
530        raise TypeError(
531            f"Cannot get type hints for {cls = }\n"
532            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
533            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
534            + f"  {e = }"
535        ) from e
536
537    return cls_type_hints
538
539
540class KWOnlyError(NotImplementedError):
541    "kw-only dataclasses are not supported in python <3.10"
542
543    pass
544
545
546class FieldError(ValueError):
547    "base class for field errors"
548
549    pass
550
551
552class NotSerializableFieldException(FieldError):
553    "field is not a `SerializableField`"
554
555    pass
556
557
558class FieldSerializationError(FieldError):
559    "error while serializing a field"
560
561    pass
562
563
564class FieldLoadingError(FieldError):
565    "error while loading a field"
566
567    pass
568
569
570class FieldTypeMismatchError(FieldError, TypeError):
571    "error when a field type does not match the type hint"
572
573    pass
574
575
576@dataclass_transform(
577    field_specifiers=(serializable_field, SerializableField),
578)
579def serializable_dataclass(
580    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
581    _cls=None,  # type: ignore
582    *,
583    init: bool = True,
584    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
585    eq: bool = True,
586    order: bool = False,
587    unsafe_hash: bool = False,
588    frozen: bool = False,
589    properties_to_serialize: Optional[list[str]] = None,
590    register_handler: bool = True,
591    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
592    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
593    **kwargs,
594):
595    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`
596
597    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
598
599    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs
600
601    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
602
603    Examines PEP 526 `__annotations__` to determine fields.
604
605    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.
606
607    ```python
608    @serializable_dataclass(kw_only=True)
609    class Myclass(SerializableDataclass):
610        a: int
611        b: str
612    ```
613    ```python
614    >>> Myclass(a=1, b="q").serialize()
615    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
616    ```
617
618    # Parameters:
619     - `_cls : _type_`
620       class to decorate. don't pass this arg, just use this as a decorator
621       (defaults to `None`)
622     - `init : bool`
623       (defaults to `True`)
624     - `repr : bool`
625       (defaults to `True`)
626     - `order : bool`
627       (defaults to `False`)
628     - `unsafe_hash : bool`
629       (defaults to `False`)
630     - `frozen : bool`
631       (defaults to `False`)
632     - `properties_to_serialize : Optional[list[str]]`
633       **SerializableDataclass only:** which properties to add to the serialized data dict
634       (defaults to `None`)
635     - `register_handler : bool`
636        **SerializableDataclass only:** if true, register the class with ZANJ for loading
637       (defaults to `True`)
638     - `on_typecheck_error : ErrorMode`
639        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
640     - `on_typecheck_mismatch : ErrorMode`
641        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
642
643    # Returns:
644     - `_type_`
645       the decorated class
646
647    # Raises:
648     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
649     - `NotSerializableFieldException` : if a field is not a `SerializableField`
650     - `FieldSerializationError` : if there is an error serializing a field
651     - `AttributeError` : if a property is not found on the class
652     - `FieldLoadingError` : if there is an error loading a field
653    """
654    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
655    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
656    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
657
658    if properties_to_serialize is None:
659        _properties_to_serialize: list = list()
660    else:
661        _properties_to_serialize = properties_to_serialize
662
663    def wrap(cls: Type[T]) -> Type[T]:
664        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
665        for field_name, field_type in cls.__annotations__.items():
666            field_value = getattr(cls, field_name, None)
667            if not isinstance(field_value, SerializableField):
668                if isinstance(field_value, dataclasses.Field):
669                    # Convert the field to a SerializableField while preserving properties
670                    field_value = SerializableField.from_Field(field_value)
671                else:
672                    # Create a new SerializableField
673                    field_value = serializable_field()
674                setattr(cls, field_name, field_value)
675
676        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
677        if sys.version_info < (3, 10):
678            if "kw_only" in kwargs:
679                if kwargs["kw_only"] == True:  # noqa: E712
680                    raise KWOnlyError("kw_only is not supported in python <3.10")
681                else:
682                    del kwargs["kw_only"]
683
684        # call `dataclasses.dataclass` to set some stuff up
685        cls = dataclasses.dataclass(  # type: ignore[call-overload]
686            cls,
687            init=init,
688            repr=repr,
689            eq=eq,
690            order=order,
691            unsafe_hash=unsafe_hash,
692            frozen=frozen,
693            **kwargs,
694        )
695
696        # copy these to the class
697        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
698
699        # ======================================================================
700        # define `serialize` func
701        # done locally since it depends on args to the decorator
702        # ======================================================================
703        def serialize(self) -> dict[str, Any]:
704            result: dict[str, Any] = {
705                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
706            }
707            # for each field in the class
708            for field in dataclasses.fields(self):  # type: ignore[arg-type]
709                # need it to be our special SerializableField
710                if not isinstance(field, SerializableField):
711                    raise NotSerializableFieldException(
712                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
713                        f"but a {type(field)} "
714                        "this state should be inaccessible, please report this bug!"
715                    )
716
717                # try to save it
718                if field.serialize:
719                    try:
720                        # get the val
721                        value = getattr(self, field.name)
722                        # if it is a serializable dataclass, serialize it
723                        if isinstance(value, SerializableDataclass):
724                            value = value.serialize()
725                        # if the value has a serialization function, use that
726                        if hasattr(value, "serialize") and callable(value.serialize):
727                            value = value.serialize()
728                        # if the field has a serialization function, use that
729                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
730                        elif field.serialization_fn:
731                            value = field.serialization_fn(value)
732
733                        # store the value in the result
734                        result[field.name] = value
735                    except Exception as e:
736                        raise FieldSerializationError(
737                            "\n".join(
738                                [
739                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
740                                    f"{field = }",
741                                    f"{value = }",
742                                    f"{self = }",
743                                ]
744                            )
745                        ) from e
746
747            # store each property if we can get it
748            for prop in self._properties_to_serialize:
749                if hasattr(cls, prop):
750                    value = getattr(self, prop)
751                    result[prop] = value
752                else:
753                    raise AttributeError(
754                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
755                        + f" but it is in {self._properties_to_serialize = }"
756                        + f"\n{self = }"
757                    )
758
759            return result
760
761        # ======================================================================
762        # define `load` func
763        # done locally since it depends on args to the decorator
764        # ======================================================================
765        # mypy thinks this isnt a classmethod
766        @classmethod  # type: ignore[misc]
767        def load(cls, data: dict[str, Any] | T) -> T:
768            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
769            if isinstance(data, cls):
770                return data
771
772            assert isinstance(
773                data, typing.Mapping
774            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
775
776            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
777
778            # initialize dict for keeping what we will pass to the constructor
779            ctor_kwargs: dict[str, Any] = dict()
780
781            # iterate over the fields of the class
782            for field in dataclasses.fields(cls):
783                # check if the field is a SerializableField
784                assert isinstance(
785                    field, SerializableField
786                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
787
788                # check if the field is in the data and if it should be initialized
789                if (field.name in data) and field.init:
790                    # get the value, we will be processing it
791                    value: Any = data[field.name]
792
793                    # get the type hint for the field
794                    field_type_hint: Any = cls_type_hints.get(field.name, None)
795
796                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
797                    if field.deserialize_fn:
798                        # if it has a deserialization function, use that
799                        value = field.deserialize_fn(value)
800                    elif field.loading_fn:
801                        # if it has a loading function, use that
802                        value = field.loading_fn(data)
803                    elif (
804                        field_type_hint is not None
805                        and hasattr(field_type_hint, "load")
806                        and callable(field_type_hint.load)
807                    ):
808                        # if no loading function but has a type hint with a load method, use that
809                        if isinstance(value, dict):
810                            value = field_type_hint.load(value)
811                        else:
812                            raise FieldLoadingError(
813                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
814                            )
815                    else:
816                        # assume no loading needs to happen, keep `value` as-is
817                        pass
818
819                    # store the value in the constructor kwargs
820                    ctor_kwargs[field.name] = value
821
822            # create a new instance of the class with the constructor kwargs
823            output: cls = cls(**ctor_kwargs)
824
825            # validate the types of the fields if needed
826            if on_typecheck_mismatch != ErrorMode.IGNORE:
827                fields_valid: dict[str, bool] = (
828                    SerializableDataclass__validate_fields_types__dict(
829                        output,
830                        on_typecheck_error=on_typecheck_error,
831                    )
832                )
833
834                # if there are any fields that are not valid, raise an error
835                if not all(fields_valid.values()):
836                    msg: str = (
837                        f"Type mismatch in fields of {cls.__name__}:\n"
838                        + "\n".join(
839                            [
840                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
841                                for k, v in fields_valid.items()
842                                if not v
843                            ]
844                        )
845                    )
846
847                    on_typecheck_mismatch.process(
848                        msg, except_cls=FieldTypeMismatchError
849                    )
850
851            # return the new instance
852            return output
853
854        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
855        # type is `Callable[[T], dict]`
856        cls.serialize = serialize  # type: ignore[attr-defined]
857        # type is `Callable[[dict], T]`
858        cls.load = load  # type: ignore[attr-defined]
859        # type is `Callable[[T, ErrorMode], bool]`
860        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
861
862        # type is `Callable[[T, T], bool]`
863        if not hasattr(cls, "__eq__"):
864            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
865
866        # Register the class with ZANJ
867        if register_handler:
868            zanj_register_loader_serializable_dataclass(cls)
869
870        return cls
871
872    if _cls is None:
873        return wrap
874    else:
875        return wrap(_cls)
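
The order in which `load` above decides how to turn a stored value back into a field value can be sketched in isolation. This is a minimal stdlib-only sketch, not the muutils implementation: `resolve_field_value` and `Point` are hypothetical names, and the keyword arguments only mirror the `deserialize_fn`, `loading_fn`, and type-hint `load` branches shown in the source.

```python
from typing import Any, Callable, Mapping, Optional


class Point:
    """Hypothetical nested type exposing a `load` classmethod."""

    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    @classmethod
    def load(cls, data: Mapping[str, Any]) -> "Point":
        return cls(data["x"], data["y"])


def resolve_field_value(
    name: str,
    data: Mapping[str, Any],
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Mapping[str, Any]], Any]] = None,
    type_hint: Any = None,
) -> Any:
    """Hypothetical helper mirroring the branch order in `load`."""
    value = data[name]
    if deserialize_fn is not None:
        # 1. a deserialization function receives only the field's value
        return deserialize_fn(value)
    if loading_fn is not None:
        # 2. a loading function receives the whole data mapping
        return loading_fn(data)
    if type_hint is not None and callable(getattr(type_hint, "load", None)):
        # 3. a type hint with a `load` method is used, but only for dict values
        if not isinstance(value, dict):
            raise TypeError(f"expected a dict for {name!r}, got {type(value)}")
        return type_hint.load(value)
    # 4. otherwise the value is kept as-is
    return value


data = {"n": "3", "p": {"x": 1, "y": 2}, "s": "raw"}
n = resolve_field_value("n", data, deserialize_fn=int)
key_count = resolve_field_value("n", data, loading_fn=len)
p = resolve_field_value("p", data, type_hint=Point)
s = resolve_field_value("s", data)
```

Note the asymmetry between the first two branches: the function passed as `deserialize_fn` sees one value, while the one passed as `loading_fn` sees the full mapping, which is why the two are mutually exclusive on a field.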

def dataclass_transform( *, eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, frozen_default: bool = False, field_specifiers: tuple[typing.Union[type[typing.Any], typing.Callable[..., typing.Any]], ...] = (), **kwargs: Any) -> _IdentityCallable:
def dataclass_transform(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> _IdentityCallable:
    """Decorator to mark an object as providing dataclass-like behaviour.

    The decorator can be applied to a function, class, or metaclass.

    Example usage with a decorator function::

        @dataclass_transform()
        def create_model[T](cls: type[T]) -> type[T]:
            ...
            return cls

        @create_model
        class CustomerModel:
            id: int
            name: str

    On a base class::

        @dataclass_transform()
        class ModelBase: ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    On a metaclass::

        @dataclass_transform()
        class ModelMeta(type): ...

        class ModelBase(metaclass=ModelMeta): ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    The ``CustomerModel`` classes defined above will
    be treated by type checkers similarly to classes created with
    ``@dataclasses.dataclass``.
    For example, type checkers will assume these classes have
    ``__init__`` methods that accept ``id`` and ``name``.

    The arguments to this decorator can be used to customize this behavior:
    - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be
        ``True`` or ``False`` if it is omitted by the caller.
    - ``order_default`` indicates whether the ``order`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``frozen_default`` indicates whether the ``frozen`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``field_specifiers`` specifies a static list of supported classes
        or functions that describe fields, similar to ``dataclasses.field()``.
    - Arbitrary other keyword arguments are accepted in order to allow for
        possible future extensions.

    At runtime, this decorator records its arguments in the
    ``__dataclass_transform__`` attribute on the decorated object.
    It has no other runtime effect.

    See PEP 681 for more details.
    """
    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn
    return decorator

Decorator to mark an object as providing dataclass-like behaviour.

The decorator can be applied to a function, class, or metaclass.

Example usage with a decorator function:

@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
    ...
    return cls

@create_model
class CustomerModel:
    id: int
    name: str

On a base class:

@dataclass_transform()
class ModelBase: ...

class CustomerModel(ModelBase):
    id: int
    name: str

On a metaclass:

@dataclass_transform()
class ModelMeta(type): ...

class ModelBase(metaclass=ModelMeta): ...

class CustomerModel(ModelBase):
    id: int
    name: str

The CustomerModel classes defined above will be treated by type checkers similarly to classes created with @dataclasses.dataclass. For example, type checkers will assume these classes have __init__ methods that accept id and name.

The arguments to this decorator can be used to customize this behavior:

  • eq_default indicates whether the eq parameter is assumed to be True or False if it is omitted by the caller.
  • order_default indicates whether the order parameter is assumed to be True or False if it is omitted by the caller.
  • kw_only_default indicates whether the kw_only parameter is assumed to be True or False if it is omitted by the caller.
  • frozen_default indicates whether the frozen parameter is assumed to be True or False if it is omitted by the caller.
  • field_specifiers specifies a static list of supported classes or functions that describe fields, similar to dataclasses.field().
  • Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.

At runtime, this decorator records its arguments in the __dataclass_transform__ attribute on the decorated object. It has no other runtime effect.

See PEP 681 for more details.
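
The runtime behaviour described above (arguments recorded in `__dataclass_transform__`, nothing else) can be checked directly. A minimal sketch, assuming Python 3.11+ or an installed `typing_extensions`; `create_model` and `CustomerModel` follow the docstring example, and building the class with `dataclasses.dataclass` is an illustrative choice, not something the decorator itself does.

```python
import dataclasses

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:
    # assumption: `typing_extensions` is available on older interpreters
    from typing_extensions import dataclass_transform


@dataclass_transform()
def create_model(cls):
    """Hypothetical model decorator; here it simply builds a plain dataclass."""
    return dataclasses.dataclass(cls)


@create_model
class CustomerModel:
    id: int
    name: str


customer = CustomerModel(id=1, name="Ada")

# the decorator only records its arguments on the decorated object
transform_info = create_model.__dataclass_transform__
```

Here `transform_info` is an ordinary dict holding the defaults passed to `dataclass_transform()`; type checkers read the decorator statically and do not need this attribute.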

class CantGetTypeHintsWarning(builtins.UserWarning):
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass

special warning for when we can't get type hints

class ZanjMissingWarning(builtins.UserWarning):
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

def zanj_register_loader_serializable_dataclass(cls: Type[~T]):
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            warnings.warn(
                "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
                ZanjMissingWarning,
            )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh

Register a serializable dataclass with the ZANJ import

this allows ZANJ().read() to load the class and not just return plain dicts

TODO: there is some duplication here with register_loader_handler
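
The `check` predicate that the handler registers can be tried out without ZANJ installed. The following is a minimal sketch: `make_format_check` is a hypothetical helper that reproduces only the `__format__` prefix test from the source above, and it does not construct or register a real `LoaderHandler`.

```python
from typing import Any, Callable


def make_format_check(cls_name: str) -> Callable[..., bool]:
    """Hypothetical helper: build the predicate used to recognise a serialized class."""
    expected = f"{cls_name}(SerializableDataclass)"

    def check(json_item: Any, path: Any = None, z: Any = None) -> bool:
        # same three conditions as the `check` lambda in the source
        return (
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(expected)
        )

    return check


check_my_class = make_format_check("MyClass")

matches = check_my_class(
    {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}
)
plain_dict = check_my_class({"a": 1, "b": "q"})
other_class = check_my_class({"__format__": "NestedClass(SerializableDataclass)"})
```

Because the test is a prefix match on the class name alone, two classes with the same name in different modules would produce the same format string in this sketch.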

class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass

Base class for warnings generated by user code.

def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

given a dataclass, check the field matches the type hint

this function is written to SerializableDataclass.validate_field_type

Parameters:

  • self : SerializableDataclass SerializableDataclass instance
  • field : SerializableField | str field to validate, will get from self.__dataclass_fields__ if an str
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function will return False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:

  • bool if the field type is correct. False if the field type is incorrect or an exception is thrown and on_typecheck_error is ignore
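
The difference between "the type is wrong" and "the check itself raised" described above can be illustrated with a small stand-in. This is a sketch only: `check_with_mode` is a hypothetical function, and the plain strings `"except"`, `"warn"`, and `"ignore"` are an assumption standing in for `ErrorMode` values, whose real interface is not reproduced here.

```python
import warnings
from typing import Any, Callable


def check_with_mode(
    value: Any,
    typecheck_fn: Callable[[Any], bool],
    on_error: str = "except",
) -> bool:
    """Hypothetical stand-in: run a type check, handling exceptions per `on_error`."""
    try:
        # a normal mismatch is just a `False` result, not an exception
        return typecheck_fn(value)
    except Exception as e:
        message = f"exception while validating type: {value = }"
        if on_error == "except":
            raise ValueError(message) from e
        if on_error == "warn":
            warnings.warn(message)
        # for "warn" and "ignore", a failed check counts as an invalid field
        return False


def broken_check(value: Any) -> bool:
    raise RuntimeError("cannot check this value")


ok = check_with_mode(3, lambda v: isinstance(v, int))
mismatch = check_with_mode("3", lambda v: isinstance(v, int))
ignored = check_with_mode(3, broken_check, on_error="ignore")
```

In this sketch a mismatch never raises; only a check that itself fails is routed through the error mode, which matches the return-value description in the list above.
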
def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> dict[str, bool]:
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict of field names to bools, where the bool is if the field type is valid
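
The shape of that return value is easy to see on a plain dataclass. A minimal sketch, assuming only the standard library: `Config` and `validate_fields` are hypothetical names, and the `isinstance` test is a simplification that only works for plain classes, whereas the source above delegates to `validate_type`.

```python
import dataclasses
import typing


@dataclasses.dataclass
class Config:
    name: str
    size: int


def validate_fields(obj: typing.Any) -> typing.Dict[str, bool]:
    """Hypothetical helper: map each field name to whether its value matches the hint."""
    hints = typing.get_type_hints(type(obj))
    return {
        f.name: isinstance(getattr(obj, f.name), hints[f.name])
        for f in dataclasses.fields(obj)
    }


# dataclasses do not enforce hints at construction, so this is accepted
results = validate_fields(Config(name="a", size="not an int"))
all_valid = all(results.values())
```

Reducing the dict with `all(...)` gives the single boolean that `SerializableDataclass__validate_fields_types` returns, while the dict itself says which fields failed.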

def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
316@dataclass_transform(
317    field_specifiers=(serializable_field, SerializableField),
318)
319class SerializableDataclass(abc.ABC):
320    """Base class for serializable dataclasses
321
322    only for linting and type checking, still need to call `serializable_dataclass` decorator
323
324    # Usage:
325
326    ```python
327    @serializable_dataclass
328    class MyClass(SerializableDataclass):
329        a: int
330        b: str
331    ```
332
333    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
334
335        >>> my_obj = MyClass(a=1, b="q")
336        >>> s = json.dumps(my_obj.serialize())
337        >>> s
338        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
339        >>> read_obj = MyClass.load(json.loads(s))
340        >>> read_obj == my_obj
341        True
342
343    This isn't too impressive on its own, but it gets more useful when you have nested classes,
344    or fields that are not json-serializable by default:
345
346    ```python
347    @serializable_dataclass
348    class NestedClass(SerializableDataclass):
349        x: str
350        y: MyClass
351        act_fun: torch.nn.Module = serializable_field(
352            default=torch.nn.ReLU(),
353            serialization_fn=lambda x: str(x),
354            deserialize_fn=lambda x: getattr(torch.nn, x)(),
355        )
356    ```
357
358    which gives us:
359
360        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
361        >>> s = json.dumps(nc.serialize())
362        >>> s
363        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
364        >>> read_nc = NestedClass.load(json.loads(s))
365        >>> read_nc == nc
366        True
367    """
368
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )
374
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
379
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )
387
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )
397
398    def __eq__(self, other: Any) -> bool:
399        return dc_eq(self, other)
400
401    def __hash__(self) -> int:
402        "hashes the json-serialized representation of the class"
403        return hash(json.dumps(self.serialize()))
404
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        if self == other:
443            return diff_result
444
445        # if we are working with serialized data, serialize the instances
446        if of_serialized:
447            ser_self: dict = self.serialize()
448            ser_other: dict = other.serialize()
449
450        # for each field in the class
451        for field in dataclasses.fields(self):  # type: ignore[arg-type]
452            # skip fields that are not for comparison
453            if not field.compare:
454                continue
455
456            # get values
457            field_name: str = field.name
458            self_value = getattr(self, field_name)
459            other_value = getattr(other, field_name)
460
461            # if the values are both serializable dataclasses, recurse
462            if isinstance(self_value, SerializableDataclass) and isinstance(
463                other_value, SerializableDataclass
464            ):
465                nested_diff: dict = self_value.diff(
466                    other_value, of_serialized=of_serialized
467                )
468                if nested_diff:
469                    diff_result[field_name] = nested_diff
470            # only support serializable dataclasses
471            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
472                other_value
473            ):
474                raise ValueError("Non-serializable dataclass is not supported")
475            else:
476                # get the values of either the serialized or the actual values
477                self_value_s = ser_self[field_name] if of_serialized else self_value
478                other_value_s = ser_other[field_name] if of_serialized else other_value
479                # compare the values
480                if not array_safe_eq(self_value_s, other_value_s):
481                    diff_result[field_name] = {"self": self_value, "other": other_value}
482
483        # return the diff result
484        return diff_result
485
486    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
487        """update the instance from a nested dict, useful for configuration from command line args
488
489        # Parameters:
490            - `nested_dict : dict[str, Any]`
491                nested dict to update the instance with
492        """
493        for field in dataclasses.fields(self):  # type: ignore[arg-type]
494            field_name: str = field.name
495            self_value = getattr(self, field_name)
496
497            if field_name in nested_dict:
498                if isinstance(self_value, SerializableDataclass):
499                    self_value.update_from_nested_dict(nested_dict[field_name])
500                else:
501                    setattr(self, field_name, nested_dict[field_name])
502
503    def __copy__(self) -> "SerializableDataclass":
504        "deep copy by serializing and loading the instance to json"
505        return self.__class__.load(json.loads(json.dumps(self.serialize())))
506
507    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
508        "deep copy by serializing and loading the instance to json"
509        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
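
The round trip through per-field serialization and deserialization functions can also be tried without muutils or torch. This is a stdlib-only sketch of the idea, not the library's mechanism: `Measurement`, `serialize`, and `load` are hypothetical, the functions are stored in `dataclasses.field` metadata as an illustrative choice, and the `"(sketch)"` format suffix is made up.

```python
import dataclasses
import json
from fractions import Fraction


@dataclasses.dataclass
class Measurement:
    label: str
    # `Fraction` is not json-serializable by default, so attach converters
    ratio: Fraction = dataclasses.field(
        default=Fraction(1, 2),
        metadata={"serialization_fn": str, "deserialize_fn": Fraction},
    )


def serialize(obj) -> dict:
    """Hypothetical: convert each field, using its converter when one is set."""
    out = {"__format__": f"{type(obj).__name__}(sketch)"}
    for f in dataclasses.fields(obj):
        fn = f.metadata.get("serialization_fn")
        value = getattr(obj, f.name)
        out[f.name] = fn(value) if fn else value
    return out


def load(cls, data: dict):
    """Hypothetical: rebuild an instance, ignoring keys that are not fields."""
    kwargs = {}
    for f in dataclasses.fields(cls):
        if f.name in data:
            fn = f.metadata.get("deserialize_fn")
            kwargs[f.name] = fn(data[f.name]) if fn else data[f.name]
    return cls(**kwargs)


m = Measurement(label="q", ratio=Fraction(3, 4))
s = json.dumps(serialize(m))
restored = load(Measurement, json.loads(s))
```

As in the examples above, the extra `__format__` key travels with the data but is not passed to the constructor when loading.
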
def serialize(self) -> dict[str, typing.Any]:
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

given a dataclass, check the field matches the type hint
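
As a rough illustration of what checking a field against its type hint involves, here is a standard-library sketch. It is an assumption-laden simplification: `Config` and `field_matches_hint` are hypothetical names, and the check only handles plain classes such as `int` or `str`, not generics like `list[int]`, which the library's validator would need to treat separately.

```python
import dataclasses
import typing


@dataclasses.dataclass
class Config:
    # hypothetical example class
    n_layers: int
    name: str


def field_matches_hint(obj: typing.Any, field_name: str) -> bool:
    """Check one field against its hint; only handles plain classes like int or str."""
    hint = typing.get_type_hints(type(obj))[field_name]
    return isinstance(getattr(obj, field_name), hint)


good = Config(n_layers=2, name="a")
bad = Config(n_layers="2", name="a")  # plain dataclasses do not check types at init
```

Here `field_matches_hint(bad, "n_layers")` is false because the stored value is a `str` while the hint is `int`.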

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
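
The recursion described above can be sketched with plain dataclasses. This is a simplified illustration rather than the library's method: `Inner`, `Outer`, and `simple_diff` are hypothetical names, values are compared with `!=` instead of `array_safe_eq`, and the `of_serialized` option and the `field.compare` check are omitted.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner


def simple_diff(left: Any, right: Any) -> dict[str, Any]:
    """Return only the fields that differ, recursing into nested dataclasses."""
    if type(left) is not type(right):
        raise ValueError("instances must be of the same type")
    result: dict[str, Any] = {}
    for field in dataclasses.fields(left):
        lv, rv = getattr(left, field.name), getattr(right, field.name)
        if dataclasses.is_dataclass(lv) and dataclasses.is_dataclass(rv):
            nested = simple_diff(lv, rv)
            if nested:  # keep nested entries only when something differs
                result[field.name] = nested
        elif lv != rv:
            result[field.name] = {"self": lv, "other": rv}
    return result


d = simple_diff(Outer("q1", Inner(1, 2)), Outer("q2", Inner(1, 3)))
# d == {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```

Fields that compare equal are left out entirely, so an empty dict means the two instances match.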
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with
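
A small sketch of the same update pattern on plain dataclasses may help show the behavior. `OptimizerCfg`, `TrainCfg`, and `update_nested` are hypothetical names used only for illustration. As in the source listing, the loop walks the declared fields, so keys in the dict that do not match a field are silently ignored.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class OptimizerCfg:
    lr: float = 1e-3
    steps: int = 100


@dataclasses.dataclass
class TrainCfg:
    name: str = "run"
    optimizer: OptimizerCfg = dataclasses.field(default_factory=OptimizerCfg)


def update_nested(obj: Any, nested: dict) -> None:
    """Overwrite fields in place, recursing into nested dataclass values."""
    for field in dataclasses.fields(obj):
        if field.name not in nested:
            continue
        current = getattr(obj, field.name)
        if dataclasses.is_dataclass(current):
            update_nested(current, nested[field.name])
        else:
            setattr(obj, field.name, nested[field.name])


cfg = TrainCfg()
# only `optimizer.lr` changes; `unknown_key` matches no field and is skipped
update_nested(cfg, {"optimizer": {"lr": 0.01}, "unknown_key": 1})
```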
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]:
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)

cached typing.get_type_hints for a class

def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]:
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints

helper function to get type hints for a class
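
To illustrate the caching step in isolation, here is a standard-library sketch. `hints_cached` and `Example` are hypothetical names; the sketch leaves out the fallback and error wrapping that `get_cls_type_hints` adds.

```python
import functools
import typing


@functools.lru_cache(typed=True)
def hints_cached(cls: type) -> dict:
    # classes are hashable, so they can be used directly as cache keys
    return typing.get_type_hints(cls)


class Example:
    a: int
    b: "str"  # string annotation, resolved to the real type by get_type_hints


first = hints_cached(Example)
second = hints_cached(Example)  # served from the cache, same dict object
```

Because the cached dict is shared between callers, mutating it would affect every later lookup, so it is best treated as read-only.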

class KWOnlyError(builtins.NotImplementedError):
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass

kw-only dataclasses are not supported in python <3.10
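
The decorator source guards this with `sys.version_info < (3, 10)`, which matches the release in which `dataclasses.dataclass` gained `kw_only`. A minimal sketch of that guard follows, with the version passed in as an argument so the logic can be exercised on any interpreter. `filter_kw_only` is a hypothetical helper, and it raises the builtin `NotImplementedError` rather than `KWOnlyError` to stay self-contained.

```python
def filter_kw_only(kwargs: dict, version_info: tuple) -> dict:
    """Drop or reject `kw_only` on interpreters older than 3.10."""
    out = dict(kwargs)  # work on a copy so the caller's dict is unchanged
    if version_info < (3, 10) and "kw_only" in out:
        if out["kw_only"] is True:
            raise NotImplementedError("kw_only requires python >= 3.10")
        # a falsy kw_only is harmless, but old dataclasses would reject the key
        del out["kw_only"]
    return out
```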

Inherited Members
builtins.NotImplementedError
NotImplementedError
builtins.BaseException
with_traceback
add_note
args
class FieldError(builtins.ValueError):
class FieldError(ValueError):
    "base class for field errors"

    pass

base class for field errors

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class NotSerializableFieldException(FieldError):
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass

field is not a SerializableField

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldSerializationError(FieldError):
class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass

error while serializing a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldLoadingError(FieldError):
class FieldLoadingError(FieldError):
    "error while loading a field"

    pass

error while loading a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldTypeMismatchError(FieldError, builtins.TypeError):
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass

error when a field type does not match the type hint

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
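
Since `FieldTypeMismatchError` derives from both `FieldError` (a `ValueError`) and `TypeError`, callers can catch it under any of those names. The sketch below redefines the two classes locally, mirroring the hierarchy listed above, so that it runs without the library installed.

```python
class FieldError(ValueError):
    "local stand-in mirroring the base class listed above"


class FieldTypeMismatchError(FieldError, TypeError):
    "local stand-in mirroring the double inheritance listed above"


caught_as = []
for handler in (FieldError, ValueError, TypeError):
    try:
        raise FieldTypeMismatchError("a: expected int, got str")
    except handler:
        # each handler class matches the same exception
        caught_as.append(handler.__name__)
```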
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, **kwargs):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `eq : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool (defaults to True)
  • repr : bool (defaults to True)
  • eq : bool (defaults to True)
  • order : bool (defaults to False)
  • unsafe_hash : bool (defaults to False)
  • frozen : bool (defaults to False)
  • properties_to_serialize : Optional[list[str]] SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
  • register_handler : bool SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
  • on_typecheck_error : ErrorMode SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false
  • on_typecheck_mismatch : ErrorMode SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True
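
The except / warn / ignore choice described for `on_typecheck_error` and `on_typecheck_mismatch` can be pictured with a small enum. This is only a sketch of the dispatch idea: `Mode` and `check` are hypothetical names and are not `muutils.errormode.ErrorMode`. The only details taken from the source are that the real class exposes a `process(msg, except_cls=...)` method and three modes; the return values here are not meant to reproduce the library's.

```python
import enum
import warnings


class Mode(enum.Enum):
    # hypothetical three-way error mode
    EXCEPT = "except"
    WARN = "warn"
    IGNORE = "ignore"

    def process(self, msg: str, except_cls: type = ValueError) -> None:
        if self is Mode.EXCEPT:
            raise except_cls(msg)
        if self is Mode.WARN:
            warnings.warn(msg)
        # IGNORE: fall through and do nothing


def check(value, expected: type, on_mismatch: Mode) -> bool:
    """Report whether `value` is an `expected`, handling a mismatch per `on_mismatch`."""
    ok = isinstance(value, expected)
    if not ok:
        on_mismatch.process(
            f"expected {expected.__name__}, got {type(value).__name__}",
            except_cls=TypeError,
        )
    return ok
```

Passing the mode as a parameter keeps the check itself free of policy, so the same function can fail hard in tests and only warn in exploratory code.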

Returns:

  • _type_ the decorated class

Raises: