docs for muutils v0.6.17

muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping strings to JSONitems.

The goal of this is that if you just want to store something as relatively human-readable JSON and don't care as much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.

it will do so by looking through DEFAULT_HANDLERS, which will keep the object as-is if it's already valid JSON, then try to find a .serialize() method on the object, and then fall back to a number of special cases. You can add handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre.
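As a rough illustration of that fallback order, here is a hedged pure-Python sketch (`sketch_serialize` is a made-up name, not muutils' actual implementation):

```python
# Pure-Python sketch of the handler-chain idea behind DEFAULT_HANDLERS.
# `sketch_serialize` is a made-up name, not part of muutils.
def sketch_serialize(obj):
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj  # base types are already valid JSONitems
    if hasattr(obj, "serialize") and callable(obj.serialize):
        return sketch_serialize(obj.serialize())  # prefer the object's own method
    if isinstance(obj, dict):
        return {str(k): sketch_serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [sketch_serialize(x) for x in obj]  # tuples become lists
    return str(obj)  # last-resort fallback: stringify

print(sketch_serialize({"a": (1, 2), "b": None}))  # {'a': [1, 2], 'b': None}
```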

additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and using serializable_field in place of dataclasses.field when defining non-standard fields.
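To make the shape of the output concrete, here is a plain-`dataclasses` mimic of what the generated `.serialize()` roughly produces; this is a hedged sketch, not the real decorator, but it mirrors the `{'__format__': ...}` layout shown in the examples:

```python
# Plain-dataclasses mimic of the auto-generated .serialize();
# a sketch only, NOT the real muutils decorator.
import dataclasses

@dataclasses.dataclass
class Myclass:
    a: int
    b: str

    def serialize(self) -> dict:
        # the real SerializableDataclass adds a "__format__" key like this
        out = {"__format__": f"{type(self).__name__}(SerializableDataclass)"}
        out.update(dataclasses.asdict(self))
        return out

print(Myclass(a=1, b="q").serialize())
# {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```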

This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.


"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping strings to `JSONitem`s.

The goal of this is that if you just want to store something as relatively human-readable JSON and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking through `DEFAULT_HANDLERS`, which will keep the object as-is if it's already valid JSON, then try to find a `.serialize()` method on the object, and then fall back to a number of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`.

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.

"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JSONitem",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]

def json_serialize( obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, **kwargs):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check: kw_only is not supported in python <3.10, and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

def serializable_field( *_args, default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[mappingproxy] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any) -> Any:
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that picks the field's value out of that dict, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"])
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x)
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
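The gotcha about calling conventions (`loading_fn` receives the whole serialized class dict, `deserialize_fn` only the field's value) can be seen in isolation with plain lambdas, no muutils imports required:

```python
# Plain-Python illustration of the loading_fn vs deserialize_fn calling
# conventions; `data` stands in for a serialized class dict.
data = {"my_field": "42"}  # the serialized dict for the whole class

loading_fn = lambda d: int(d["my_field"])  # gets the whole dict
deserialize_fn = lambda v: int(v)          # gets just this field's value

assert loading_fn(data) == 42
assert deserialize_fn(data["my_field"]) == 42
```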


def arr_metadata(arr) -> dict[str, list[int] | str | int]:
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }
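Assuming numpy is available, the resulting metadata dict looks like this (computed directly here rather than via `arr_metadata`, since `array_n_elements` is an internal helper):

```python
# Sketch of arr_metadata-style output, computed directly with numpy.
import numpy as np

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
meta = {
    "shape": list(arr.shape),
    "dtype": str(arr.dtype),      # the real code prefers arr.dtype.__name__ when present
    "n_elements": int(arr.size),  # stand-in for the internal array_n_elements helper
}
print(meta)  # {'shape': [2, 3], 'dtype': 'float32', 'n_elements': 6}
```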

def load_array( arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Any:
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if it's already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")

load a json-serialized array, infer the mode if not specified
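
To illustrate the `"array_b64_meta"` layout that `load_array` consumes, here is a stdlib-only sketch. The payload keys mirror the ones the source above reads (`data`, `dtype`, `shape`); the byte round-trip is verified with `struct` instead of numpy so the example stands alone:

```python
import base64
import struct

# sketch of an "array_b64_meta" payload: raw little-endian float32 bytes,
# base64-encoded, plus dtype and shape metadata
values = [1.0, 2.0, 3.0, 4.0]
raw = struct.pack("<4f", *values)
payload = {
    "data": base64.b64encode(raw).decode("ascii"),
    "dtype": "float32",
    "shape": [2, 2],
}

# load_array would call np.frombuffer(base64.b64decode(payload["data"]),
# dtype=payload["dtype"]) and reshape to payload["shape"]; here we just
# confirm the bytes survive the base64 round-trip
decoded = struct.unpack("<4f", base64.b64decode(payload["data"]))
```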

BASE_HANDLERS = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
JSONitem = typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]
class JsonSerializer:
234class JsonSerializer:
235    """Json serialization class (holds configs)
236
237    # Parameters:
238    - `array_mode : ArrayMode`
239    how to write arrays
240    (defaults to `"array_list_meta"`)
241    - `error_mode : ErrorMode`
242    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
243    (defaults to `"except"`)
244    - `handlers_pre : MonoTuple[SerializerHandler]`
245    handlers to use before the default handlers
246    (defaults to `tuple()`)
247    - `handlers_default : MonoTuple[SerializerHandler]`
248    default handlers to use
249    (defaults to `DEFAULT_HANDLERS`)
250    - `write_only_format : bool`
251    changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
252    (defaults to `False`)
253
254    # Raises:
255    - `ValueError`: on init, if `args` is not empty
256    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
257
258    """
259
260    def __init__(
261        self,
262        *args,
263        array_mode: ArrayMode = "array_list_meta",
264        error_mode: ErrorMode = ErrorMode.EXCEPT,
265        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
266        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
267        write_only_format: bool = False,
268    ):
269        if len(args) > 0:
270            raise ValueError(
271                f"JsonSerializer takes no positional arguments!\n{args = }"
272            )
273
274        self.array_mode: ArrayMode = array_mode
275        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
276        self.write_only_format: bool = write_only_format
277        # join up the handlers
278        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
279            handlers_default
280        )
281
282    def json_serialize(
283        self,
284        obj: Any,
285        path: ObjectPath = tuple(),
286    ) -> JSONitem:
287        try:
288            for handler in self.handlers:
289                if handler.check(self, obj, path):
290                    output: JSONitem = handler.serialize_func(self, obj, path)
291                    if self.write_only_format:
292                        if isinstance(output, dict) and "__format__" in output:
293                            new_fmt: JSONitem = output.pop("__format__")
294                            output["__write_format__"] = new_fmt
295                    return output
296
297            raise ValueError(f"no handler found for object with {type(obj) = }")
298
299        except Exception as e:
300            if self.error_mode == "except":
301                obj_str: str = repr(obj)
302                if len(obj_str) > 1000:
303                    obj_str = obj_str[:1000] + "..."
304                raise SerializationException(
305                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
306                ) from e
307            elif self.error_mode == "warn":
308                warnings.warn(
309                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
310                )
311
312            return repr(obj)
313
314    def hashify(
315        self,
316        obj: Any,
317        path: ObjectPath = tuple(),
318        force: bool = True,
319    ) -> Hashableitem:
320        """try to turn any object into something hashable"""
321        data = self.json_serialize(obj, path=path)
322
323        # recursive hashify, turning dicts and lists into tuples
324        return _recursive_hashify(data, force=force)

Json serialization class (holds configs)

Parameters:

  • array_mode : ArrayMode how to write arrays (defaults to "array_list_meta")
  • error_mode : ErrorMode what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
  • handlers_pre : MonoTuple[SerializerHandler] handlers to use before the default handlers (defaults to tuple())
  • handlers_default : MonoTuple[SerializerHandler] default handlers to use (defaults to DEFAULT_HANDLERS)
  • write_only_format : bool changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to False)

Raises:

  • ValueError: on init, if args is not empty
  • SerializationException: on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT
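
The dispatch loop inside `json_serialize` can be sketched in a self-contained way; `Handler` and `dispatch` below are simplified stand-ins for `SerializerHandler` and the real method, not the muutils API itself:

```python
from dataclasses import dataclass
from typing import Any, Callable

# each handler pairs a check() predicate with a serialize_func();
# the first handler whose check passes serializes the object
@dataclass
class Handler:
    check: Callable[[Any], bool]
    serialize_func: Callable[[Any], Any]
    uid: str

def dispatch(obj: Any) -> Any:
    for h in HANDLERS:
        if h.check(obj):
            return h.serialize_func(obj)
    raise ValueError(f"no handler found for object with {type(obj) = }")

HANDLERS = [
    Handler(lambda o: isinstance(o, (bool, int, float, str, type(None))),
            lambda o: o, "base types"),
    Handler(lambda o: isinstance(o, (list, tuple)),
            lambda o: [dispatch(x) for x in o], "(list, tuple) -> list"),
    # like the real fallback handler, catch everything else as a string
    Handler(lambda o: True, lambda o: repr(o), "fallback"),
]

result = dispatch((1, "a", None))
# [1, "a", None]
```

Prepending a `Handler` to the list before the defaults is the same idea as passing `handlers_pre` to `JsonSerializer`.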
JsonSerializer( *args, array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta', error_mode: muutils.errormode.ErrorMode = ErrorMode.Except, handlers_pre: MonoTuple[SerializerHandler] = (), handlers_default: MonoTuple[SerializerHandler] = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings')), write_only_format: bool = False)
260    def __init__(
261        self,
262        *args,
263        array_mode: ArrayMode = "array_list_meta",
264        error_mode: ErrorMode = ErrorMode.EXCEPT,
265        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
266        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
267        write_only_format: bool = False,
268    ):
269        if len(args) > 0:
270            raise ValueError(
271                f"JsonSerializer takes no positional arguments!\n{args = }"
272            )
273
274        self.array_mode: ArrayMode = array_mode
275        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
276        self.write_only_format: bool = write_only_format
277        # join up the handlers
278        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
279            handlers_default
280        )
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
write_only_format: bool
handlers: MonoTuple[SerializerHandler]
def json_serialize( self, obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
282    def json_serialize(
283        self,
284        obj: Any,
285        path: ObjectPath = tuple(),
286    ) -> JSONitem:
287        try:
288            for handler in self.handlers:
289                if handler.check(self, obj, path):
290                    output: JSONitem = handler.serialize_func(self, obj, path)
291                    if self.write_only_format:
292                        if isinstance(output, dict) and "__format__" in output:
293                            new_fmt: JSONitem = output.pop("__format__")
294                            output["__write_format__"] = new_fmt
295                    return output
296
297            raise ValueError(f"no handler found for object with {type(obj) = }")
298
299        except Exception as e:
300            if self.error_mode == "except":
301                obj_str: str = repr(obj)
302                if len(obj_str) > 1000:
303                    obj_str = obj_str[:1000] + "..."
304                raise SerializationException(
305                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
306                ) from e
307            elif self.error_mode == "warn":
308                warnings.warn(
309                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
310                )
311
312            return repr(obj)
def hashify( self, obj: Any, path: tuple[typing.Union[str, int], ...] = (), force: bool = True) -> Union[bool, int, float, str, tuple]:
314    def hashify(
315        self,
316        obj: Any,
317        path: ObjectPath = tuple(),
318        force: bool = True,
319    ) -> Hashableitem:
320        """try to turn any object into something hashable"""
321        data = self.json_serialize(obj, path=path)
322
323        # recursive hashify, turning dicts and lists into tuples
324        return _recursive_hashify(data, force=force)

try to turn any object into something hashable
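
`_recursive_hashify` is private, but its documented effect (dicts and lists become tuples) can be sketched as follows; the exact semantics of the real function may differ:

```python
def recursive_hashify_sketch(data):
    # dicts become tuples of (key, value) pairs, lists/tuples become
    # tuples, scalars pass through unchanged
    if isinstance(data, dict):
        return tuple((k, recursive_hashify_sketch(v)) for k, v in data.items())
    if isinstance(data, (list, tuple)):
        return tuple(recursive_hashify_sketch(v) for v in data)
    return data

hashable = recursive_hashify_sketch({"a": [1, 2], "b": None})
# (("a", (1, 2)), ("b", None)) -- a valid dict key or set member
```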

def try_catch(func: Callable):
81def try_catch(func: Callable):
82    """wraps the function to catch exceptions, returns serialized error message on exception
83
84    returned func will return normal result on success, or error message on exception
85    """
86
87    @functools.wraps(func)
88    def newfunc(*args, **kwargs):
89        try:
90            return func(*args, **kwargs)
91        except Exception as e:
92            return f"{e.__class__.__name__}: {e}"
93
94    return newfunc

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception
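
Usage example, reusing the implementation shown above verbatim; `divide` is a made-up function for illustration:

```python
import functools
from typing import Callable

def try_catch(func: Callable):
    # same implementation as shown in the source above
    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"
    return newfunc

@try_catch
def divide(a: float, b: float) -> float:
    return a / b

divide(6, 3)  # 2.0
divide(1, 0)  # "ZeroDivisionError: division by zero"
```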

def dc_eq( dc1, dc2, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
175def dc_eq(
176    dc1,
177    dc2,
178    except_when_class_mismatch: bool = False,
179    false_when_class_mismatch: bool = True,
180    except_when_field_mismatch: bool = False,
181) -> bool:
182    """
183    checks if two dataclasses which (might) hold numpy arrays are equal
184
185    # Parameters:
186
187    - `dc1`: the first dataclass
188    - `dc2`: the second dataclass
189    - `except_when_class_mismatch: bool`
190        if `True`, will throw `TypeError` if the classes are different.
191        if not, will return `False` by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
192        (default: `False`)
193    - `false_when_class_mismatch: bool`
194        only relevant if `except_when_class_mismatch` is `False`.
195        if `True`, will return `False` if the classes are different.
196        if `False`, will attempt to compare the fields.
197    - `except_when_field_mismatch: bool`
198        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
199        if `True`, will throw `TypeError` if the fields are different.
200        (default: `False`)
201
202    # Returns:
203    - `bool`: True if the dataclasses are equal, False otherwise
204
205    # Raises:
206    - `TypeError`: if the dataclasses are of different classes
207    - `AttributeError`: if the dataclasses have different fields
208
209    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
210    ```
211              [START]
212                 ▼
213           ┌───────────┐  ┌─────────┐
214           │dc1 is dc2?├─►│ classes │
215           └──┬────────┘No│ match?  │
216      ────    │           ├─────────┤
217     (True)◄──┘Yes        │No       │Yes
218      ────                ▼         ▼
219          ┌────────────────┐ ┌────────────┐
220          │ except when    │ │ fields keys│
221          │ class mismatch?│ │ match?     │
222          ├───────────┬────┘ ├───────┬────┘
223          │Yes        │No    │No     │Yes
224          ▼           ▼      ▼       ▼
225     ───────────  ┌──────────┐  ┌────────┐
226    { raise     } │ except   │  │ field  │
227    { TypeError } │ when     │  │ values │
228     ───────────  │ field    │  │ match? │
229                  │ mismatch?│  ├────┬───┘
230                  ├───────┬──┘  │    │Yes
231                  │Yes    │No   │No  ▼
232                  ▼       ▼     │   ────
233     ───────────────     ─────  │  (True)
234    { raise         }   (False)◄┘   ────
235    { AttributeError}    ─────
236     ───────────────
237    ```
238
239    """
240    if dc1 is dc2:
241        return True
242
243    if dc1.__class__ is not dc2.__class__:
244        if except_when_class_mismatch:
245            # if the classes don't match, raise an error
246            raise TypeError(
247                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
248            )
249        if except_when_field_mismatch:
250            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
251            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
252            fields_match: bool = set(dc1_fields) == set(dc2_fields)
253            if not fields_match:
254                # if the fields don't match, raise an error
255                raise AttributeError(
256                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
257                )
258        return False
259
260    return all(
261        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
262        for fld in dataclasses.fields(dc1)
263        if fld.compare
264    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return False by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw TypeError if the fields are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields

TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?

          [START]
             ▼
       ┌───────────┐  ┌─────────┐
       │dc1 is dc2?├─►│ classes │
       └──┬────────┘No│ match?  │
  ────    │           ├─────────┤
 (True)◄──┘Yes        │No       │Yes
  ────                ▼         ▼
      ┌────────────────┐ ┌────────────┐
      │ except when    │ │ fields keys│
      │ class mismatch?│ │ match?     │
      ├───────────┬────┘ ├───────┬────┘
      │Yes        │No    │No     │Yes
      ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{ raise     } │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{ raise         }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
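
A simplified sketch of the decision logic with plain dataclasses; `dc_eq_sketch` is a hypothetical stand-in that uses plain `==` instead of `array_safe_eq` and omits the field-mismatch branch:

```python
import dataclasses

@dataclasses.dataclass
class Point:
    x: int
    y: int
    # compare=False fields are skipped, just as in the real dc_eq
    label: str = dataclasses.field(default="", compare=False)

def dc_eq_sketch(dc1, dc2, except_when_class_mismatch: bool = False) -> bool:
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            raise TypeError("class mismatch")
        return False
    return all(
        getattr(dc1, fld.name) == getattr(dc2, fld.name)
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
```

`Point(1, 2, "a")` and `Point(1, 2, "b")` compare equal because `label` has `compare=False`.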
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
316@dataclass_transform(
317    field_specifiers=(serializable_field, SerializableField),
318)
319class SerializableDataclass(abc.ABC):
320    """Base class for serializable dataclasses
321
322    only for linting and type checking, still need to call `serializable_dataclass` decorator
323
324    # Usage:
325
326    ```python
327    @serializable_dataclass
328    class MyClass(SerializableDataclass):
329        a: int
330        b: str
331    ```
332
333    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
334
335        >>> my_obj = MyClass(a=1, b="q")
336        >>> s = json.dumps(my_obj.serialize())
337        >>> s
338        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
339        >>> read_obj = MyClass.load(json.loads(s))
340        >>> read_obj == my_obj
341        True
342
343    This isn't too impressive on its own, but it gets more useful when you have nested classes,
344    or fields that are not json-serializable by default:
345
346    ```python
347    @serializable_dataclass
348    class NestedClass(SerializableDataclass):
349        x: str
350        y: MyClass
351        act_fun: torch.nn.Module = serializable_field(
352            default=torch.nn.ReLU(),
353            serialization_fn=lambda x: str(x),
354            deserialize_fn=lambda x: getattr(torch.nn, x)(),
355        )
356    ```
357
358    which gives us:
359
360        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
361        >>> s = json.dumps(nc.serialize())
362        >>> s
363        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
364        >>> read_nc = NestedClass.load(json.loads(s))
365        >>> read_nc == nc
366        True
367    """
368
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )
374
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
379
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )
387
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )
397
398    def __eq__(self, other: Any) -> bool:
399        return dc_eq(self, other)
400
401    def __hash__(self) -> int:
402        "hashes the json-serialized representation of the class"
403        return hash(json.dumps(self.serialize()))
404
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        if self == other:
443            return diff_result
444
445        # if we are working with serialized data, serialize the instances
446        if of_serialized:
447            ser_self: dict = self.serialize()
448            ser_other: dict = other.serialize()
449
450        # for each field in the class
451        for field in dataclasses.fields(self):  # type: ignore[arg-type]
452            # skip fields that are not for comparison
453            if not field.compare:
454                continue
455
456            # get values
457            field_name: str = field.name
458            self_value = getattr(self, field_name)
459            other_value = getattr(other, field_name)
460
461            # if the values are both serializable dataclasses, recurse
462            if isinstance(self_value, SerializableDataclass) and isinstance(
463                other_value, SerializableDataclass
464            ):
465                nested_diff: dict = self_value.diff(
466                    other_value, of_serialized=of_serialized
467                )
468                if nested_diff:
469                    diff_result[field_name] = nested_diff
470            # only support serializable dataclasses
471            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
472                other_value
473            ):
474                raise ValueError("Non-serializable dataclass is not supported")
475            else:
476                # get the values of either the serialized or the actual values
477                self_value_s = ser_self[field_name] if of_serialized else self_value
478                other_value_s = ser_other[field_name] if of_serialized else other_value
479                # compare the values
480                if not array_safe_eq(self_value_s, other_value_s):
481                    diff_result[field_name] = {"self": self_value, "other": other_value}
482
483        # return the diff result
484        return diff_result
485
486    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
487        """update the instance from a nested dict, useful for configuration from command line args
488
489        # Parameters:
490            - `nested_dict : dict[str, Any]`
491                nested dict to update the instance with
492        """
493        for field in dataclasses.fields(self):  # type: ignore[arg-type]
494            field_name: str = field.name
495            self_value = getattr(self, field_name)
496
497            if field_name in nested_dict:
498                if isinstance(self_value, SerializableDataclass):
499                    self_value.update_from_nested_dict(nested_dict[field_name])
500                else:
501                    setattr(self, field_name, nested_dict[field_name])
502
503    def __copy__(self) -> "SerializableDataclass":
504        "deep copy by serializing and loading the instance to json"
505        return self.__class__.load(json.loads(json.dumps(self.serialize())))
506
507    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
508        "deep copy by serializing and loading the instance to json"
509        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        if self == other:
443            return diff_result
444
445        # if we are working with serialized data, serialize the instances
446        if of_serialized:
447            ser_self: dict = self.serialize()
448            ser_other: dict = other.serialize()
449
450        # for each field in the class
451        for field in dataclasses.fields(self):  # type: ignore[arg-type]
452            # skip fields that are not for comparison
453            if not field.compare:
454                continue
455
456            # get values
457            field_name: str = field.name
458            self_value = getattr(self, field_name)
459            other_value = getattr(other, field_name)
460
461            # if the values are both serializable dataclasses, recurse
462            if isinstance(self_value, SerializableDataclass) and isinstance(
463                other_value, SerializableDataclass
464            ):
465                nested_diff: dict = self_value.diff(
466                    other_value, of_serialized=of_serialized
467                )
468                if nested_diff:
469                    diff_result[field_name] = nested_diff
470            # only support serializable dataclasses
471            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
472                other_value
473            ):
474                raise ValueError("Non-serializable dataclass is not supported")
475            else:
476                # get the values of either the serialized or the actual values
477                self_value_s = ser_self[field_name] if of_serialized else self_value
478                other_value_s = ser_other[field_name] if of_serialized else other_value
479                # compare the values
480                if not array_safe_eq(self_value_s, other_value_s):
481                    diff_result[field_name] = {"self": self_value, "other": other_value}
482
483        # return the diff result
484        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

- `other : SerializableDataclass`
    other instance to compare against
- `of_serialized : bool`
    if true, compare serialized data and not raw values (defaults to `False`)

Returns:

- `dict[str, Any]`

Raises:

- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`

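The recursive diff logic above can be sketched for plain dataclasses. This is an illustrative re-implementation, not the library's method: `dataclass_diff` is a hypothetical name, it skips the `of_serialized` branch, and it compares leaf values with plain `!=` rather than the library's `array_safe_eq`:

```python
import dataclasses


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner


def dataclass_diff(self_obj, other_obj) -> dict:
    # mirror the structure of SerializableDataclass.diff:
    # same-type check, early exit on equality, then per-field comparison
    if type(self_obj) is not type(other_obj):
        raise ValueError("instances must be of the same type")
    result: dict = {}
    if self_obj == other_obj:
        return result
    for field in dataclasses.fields(self_obj):
        if not field.compare:
            continue
        sv = getattr(self_obj, field.name)
        ov = getattr(other_obj, field.name)
        if dataclasses.is_dataclass(sv) and dataclasses.is_dataclass(ov):
            # recurse into nested dataclasses, keeping only non-empty diffs
            nested = dataclass_diff(sv, ov)
            if nested:
                result[field.name] = nested
        elif sv != ov:
            result[field.name] = {"self": sv, "other": ov}
    return result


print(dataclass_diff(Outer("q1", Inner(1, 2)), Outer("q2", Inner(1, 3))))
# {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```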
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
486    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
487        """update the instance from a nested dict, useful for configuration from command line args
488
489        # Parameters:
490            - `nested_dict : dict[str, Any]`
491                nested dict to update the instance with
492        """
493        for field in dataclasses.fields(self):  # type: ignore[arg-type]
494            field_name: str = field.name
495            self_value = getattr(self, field_name)
496
497            if field_name in nested_dict:
498                if isinstance(self_value, SerializableDataclass):
499                    self_value.update_from_nested_dict(nested_dict[field_name])
500                else:
501                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with
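A minimal sketch of the same update pattern for plain dataclasses, e.g. overriding parts of a config from parsed command-line args (the `Config`/`Training` classes and the standalone `update_from_nested_dict` function here are illustrative, not part of the library):

```python
import dataclasses


@dataclasses.dataclass
class Training:
    lr: float = 1e-3
    epochs: int = 10


@dataclasses.dataclass
class Config:
    name: str = "run"
    training: Training = dataclasses.field(default_factory=Training)


def update_from_nested_dict(obj, nested: dict) -> None:
    # mirror of the method body: recurse into nested dataclasses,
    # setattr on leaf fields that appear in the dict, leave the rest alone
    for field in dataclasses.fields(obj):
        if field.name in nested:
            value = getattr(obj, field.name)
            if dataclasses.is_dataclass(value):
                update_from_nested_dict(value, nested[field.name])
            else:
                setattr(obj, field.name, nested[field.name])


cfg = Config()
update_from_nested_dict(cfg, {"training": {"lr": 0.01}})
print(cfg.training.lr)      # 0.01 (overridden)
print(cfg.training.epochs)  # 10 (untouched)
print(cfg.name)             # run (untouched)
```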