Coverage for muutils\json_serialize\util.py: 44%
112 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-12 20:43 -0700
1"""utilities for json_serialize"""
3from __future__ import annotations
5import dataclasses
6import functools
7import inspect
8import sys
9import typing
10import warnings
11from typing import Any, Callable, Iterable, Union
# flag recording whether numpy could be imported in this environment
_NUMPY_WORKING: bool
try:
    # the import must actually happen inside the `try`: without it no
    # `ImportError` could ever be raised and the flag would always be `True`
    import numpy as _numpy  # noqa: F401

    _NUMPY_WORKING = True
except ImportError:
    warnings.warn("numpy not found, cannot serialize numpy arrays!")
    _NUMPY_WORKING = False
# a single JSON-compatible value: scalar, string, list, string-keyed dict, or null
JSONitem = typing.Union[
    bool,
    int,
    float,
    str,
    list,
    typing.Dict[str, Any],
    None,
]
# a JSON object: string keys mapped to JSON-compatible values
JSONdict = typing.Dict[str, JSONitem]
# the hashable kinds of value produced by the hashify helpers
Hashableitem = typing.Union[bool, int, float, str, tuple]
# `MonoTuple[T]` means "a tuple of any length whose elements all have type T".
# Static type checkers, and python < 3.9 (where `types.GenericAlias`, used
# below, does not exist), get `typing.Sequence` as a stand-in; at runtime on
# 3.9+ a subscriptable helper class builds the equivalent `tuple[T, ...]`.
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type

        `MonoTuple[T]` evaluates to `tuple[T, ...]` and `MonoTuple[()]` to plain
        `tuple`. The class itself can be neither instantiated nor subclassed.
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # `_tp_cache` is a private `typing` helper that memoizes subscriptions;
        # mypy does not know about it, hence the ignore
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            """build the `tuple[params, ...]` alias for `MonoTuple[params]`

            # Parameters:
             - `params`: a single element type (a class, a `Union`, a generic
               alias, a forward-reference string, ...), or a tuple/list holding
               zero or one such type

            # Returns:
             - `tuple[T, ...]` for a single type `T`, or `tuple` for `()`

            # Raises:
             - `TypeError`: if more than one type argument is given
            """
            import types

            # `MonoTuple[a, b]` passes `(a, b)`; a list is accepted the same way
            if isinstance(params, (tuple, list)):
                if len(params) == 0:
                    return tuple
                if len(params) == 1:
                    return types.GenericAlias(tuple, (params[0], Ellipsis))
                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")

            # anything else is the single element type; previously a plain class
            # fell through without a `return`, yielding `None`
            return types.GenericAlias(tuple, (params, Ellipsis))
class UniversalContainer:
    """a container that claims to hold every possible object

    any membership test succeeds: `x in UniversalContainer()` is `True` for
    every `x`
    """

    def __contains__(self, x: Any) -> bool:
        # nothing is ever excluded, so the candidate is not even inspected
        is_member: bool = True
        return is_member
def isinstance_namedtuple(x: Any) -> bool:
    """return whether `x` looks like an instance of a `namedtuple`

    the check is structural: the type of `x` must have `tuple` as its one and
    only direct base, and must carry a `_fields` tuple made up entirely of
    strings.

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)
    parents: tuple = cls.__bases__
    # exactly one direct base, and it must be `tuple` itself
    if not (len(parents) == 1 and parents[0] is tuple):
        return False
    field_names: Any = getattr(cls, "_fields", None)
    return isinstance(field_names, tuple) and all(
        isinstance(name, str) for name in field_names
    )
def try_catch(func: Callable):
    """decorator that turns exceptions raised by `func` into a returned string

    the wrapper returns `func`'s result unchanged on success. If `func` raises
    an `Exception` (not a bare `BaseException`), the wrapper instead returns
    `"<ExceptionClassName>: <message>"`.
    """

    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return outcome

    return _guarded
def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
    """recursively convert `obj` into a hashable structure

    - `bool` / `int` / `float` / `str` are returned unchanged
    - mappings become a tuple of `(key, hashified_value)` pairs, in iteration order
    - any other iterable becomes a tuple of its hashified elements
    - anything else becomes `str(obj)` when `force` is `True`

    # Parameters:
     - `obj`: the object to convert
     - `force`: if `True`, stringify unhashable leaves; if `False`, raise instead.
       Applies at every nesting level.

    # Raises:
     - `ValueError`: if `force` is `False` and `obj`, or anything nested in it,
       cannot be hashified
    """
    # primitives must be checked first: `str` is itself `Iterable`, and a
    # one-character string iterates to itself, so the iterable branch would
    # recurse forever on any string
    if isinstance(obj, (bool, int, float, str)):
        return obj
    elif isinstance(obj, typing.Mapping):
        return tuple((k, _recursive_hashify(v, force=force)) for k, v in obj.items())
    elif isinstance(obj, Iterable):
        return tuple(_recursive_hashify(v, force=force) for v in obj)
    elif force:
        return str(obj)
    else:
        raise ValueError(f"cannot hashify:\n{obj}")
class SerializationException(Exception):
    """exception type for serialization failures in this package

    adds no behavior beyond `Exception`; see the raising sites for the exact
    conditions under which it is used
    """
def string_as_lines(s: str | None) -> list[str]:
    """split `s` on line boundaries, dropping the line endings

    makes long strings easier to read when stored in json, similar to how
    jupyter notebooks store cell sources. `None` yields an empty list.
    """
    return [] if s is None else s.splitlines()
def safe_getsource(func) -> list[str]:
    """return the source code of `func` as a list of lines, without raising

    if `inspect.getsource` fails for any reason (e.g. a builtin, or no source
    file available), the returned lines describe the error instead.
    """
    try:
        source_text: str = inspect.getsource(func)
    except Exception as err:
        return string_as_lines(f"Error: Unable to retrieve source code:\n{err}")
    return string_as_lines(source_text)
# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors

    plain `a == b` returns an elementwise result for numpy arrays and torch
    tensors. Those are therefore compared by shape plus `.all()`. Sequences
    and mappings are compared element by element with this same function.

    # Parameters:
     - `a`, `b`: the objects to compare

    # Returns:
     - `bool`: whether the objects are equal. Objects of different types are
       never equal. If `a == b` raises `TypeError`/`ValueError`, a warning is
       issued and `False` is returned.
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # both objects now share a type, so inspecting `a` suffices.
    # types are matched by name to avoid importing numpy / torch / pandas
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # compare shapes first: `==` broadcasts, which would otherwise make
        # e.g. `[1, 1, 1]` "equal" to `[1]`
        return a.shape == b.shape and bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return bool(a.equals(b))

    # a string is a sequence of one-character strings, each of which contains
    # itself, so the sequence branch below would never terminate on them
    if isinstance(a, (str, bytes)):
        return a == b

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        return len(a) == len(b) and all(
            array_safe_eq(a1, b1) for a1, b1 in zip(a, b)
        )

    if isinstance(a, typing.Mapping) and isinstance(b, typing.Mapping):
        # match values by key so insertion order does not matter, like `dict.__eq__`
        return len(a) == len(b) and all(
            k in b and array_safe_eq(a[k], b[k]) for k in a
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # `NotImplemented` (returned previously) is truthy, so callers such as
        # `all(...)` treated incomparable objects as equal
        return False
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return `False` by default, or attempt to compare the fields
        if `false_when_class_mismatch` is `False`.
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will compare the field names, and if those match, the field values.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes are different and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different
        (this check happens regardless of `false_when_class_mismatch`).
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the classes are different and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the classes and field names are different and `except_when_field_mismatch` is `True`

    # Decision procedure:
    1. `dc1 is dc2` -> `True`
    2. if the classes differ:
        a. `except_when_class_mismatch` -> raise `TypeError`
        b. if `except_when_field_mismatch` or not `false_when_class_mismatch`,
           compare the field names; on mismatch raise `AttributeError` if
           `except_when_field_mismatch`, else return `False`
        c. `false_when_class_mismatch` -> `False`
        d. otherwise continue to step 3
    3. compare every field of `dc1` declared with `compare=True` using
       `array_safe_eq`; `True` only if all of them match
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        # field names are only inspected when an option needs them, so the
        # default path never calls `dataclasses.fields` on a non-dataclass
        if except_when_field_mismatch or not false_when_class_mismatch:
            dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
            dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
            if dc1_fields != dc2_fields:
                if except_when_field_mismatch:
                    raise AttributeError(
                        f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                    )
                return False
        if false_when_class_mismatch:
            return False
        # classes differ but the field names match: fall through and compare
        # the values (previously `false_when_class_mismatch=False` was ignored)

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )