# Source code for cosapp.utils.state_io

from __future__ import annotations
import sys
import copy
from typing import Any, Optional, Union, TYPE_CHECKING

if TYPE_CHECKING:
    from cosapp.base import System, Port


# Type alias for the nested dictionary produced by `get_state` and consumed by
# `set_state`: a 'ports' entry (port name -> {variable name: value}) and, for
# systems with sub-systems, a 'children' entry (child name -> nested state of
# the same form).
SystemState = dict[str, dict[str, dict]]


def get_state(system: System) -> SystemState:
    """Export system state (inputs and outputs) into a dictionary for quick
    serialization, without type check.

    Port values are deep-copied, so the returned dictionary shares no mutable
    data with the system. Ports holding no variable are omitted, and the
    'children' entry is only present for systems having sub-systems.

    Parameters
    ----------
    system: System
        System whose state is exported.

    Returns
    -------
    SystemState
        Recursive dictionary of the kind::

            {
                'ports': {
                    port_name: {varname: value, ...},
                    ...
                },
                'children': {
                    child_name: {
                        'ports': {...},
                        'children': {...},
                    },
                    ...
                },
            }
    """
    def export(node: System) -> SystemState:
        # Build the state dictionary of a single (sub-)system, recursing into children.
        ports = {}
        for port in node.ports():
            data = {
                varname: copy.deepcopy(value)
                for varname, value in port.items()
            }
            if data:  # skip empty ports to keep the state compact
                ports[port.name] = data
        state = {'ports': ports}
        if node.children:
            state['children'] = {
                child.name: export(child)
                for child in node.children.values()
            }
        return state

    return export(system)
def set_state(system: System, state: SystemState) -> None:
    """Restore the inputs and outputs of `system`, and of its sub-systems,
    from a dictionary generated by `get_state`.

    Parameters
    ----------
    system: System
        Target system; its ports are updated in place via `set_values`.
    state: SystemState
        Dictionary with a mandatory 'ports' entry and an optional 'children'
        entry. Values are deep-copied before being assigned, so `state` is
        not shared with the system afterwards.
    """
    port_states = state['ports']
    for port_name in port_states:
        target: Port = getattr(system, port_name)
        target.set_values(**copy.deepcopy(port_states[port_name]))

    # Sub-systems are restored depth-first, in the order of the 'children' entry
    for child_name, sub_state in state.get('children', {}).items():
        set_state(system[child_name], sub_state)
# Compare full version tuples: testing only the minor number would misfire on
# a future major release (e.g. 4.0).
if sys.version_info < (3, 11):

    def object__getstate__(
        obj: Any,
    ) -> Union[None, dict[str, Any], tuple[Optional[dict[str, Any]], dict[str, Any]]]:
        """Creates a state of an object, emulating `object.__getstate__`
        (available natively from Python 3.11 onwards).

        The state may take various forms depending on the object, see
        https://docs.python.org/3/library/pickle.html#object.__getstate__
        for further details. Here:
        - if at least one slot is assigned, the state is the tuple
          `(instance_dict_or_None, {slot_name: value})`;
        - otherwise, it is the instance `__dict__`, or `None` if the
          object has no `__dict__`.

        Parameters
        ----------
        obj: Any
            Object from which the state must be constructed

        Returns
        -------
        Union[None, dict[str, Any], tuple[Optional[dict[str, Any]], dict[str, Any]]]:
            state
        """
        def mangle(cls: type, name: str) -> str:
            # Private slots (`__x`) are stored as `_ClassName__x`, with leading
            # underscores of the class name stripped (Python name mangling).
            if name.startswith("__") and not name.endswith("__"):
                stripped = cls.__name__.lstrip("_")
                if stripped:
                    return f"_{stripped}{name}"
            return name

        slot_names: list[str] = []
        for cls in type(obj).__mro__:
            # Only consider slots declared by `cls` itself; `getattr` would
            # also return `__slots__` inherited from a base class.
            slots = vars(cls).get("__slots__", ())
            if isinstance(slots, str):
                # `__slots__ = "x"` declares one slot, not one per character
                slots = (slots,)
            for name in slots:
                if name in ("__dict__", "__weakref__"):
                    continue
                slot_names.append(mangle(cls, name))

        slot_values: dict[str, Any] = {}
        for name in slot_names:
            try:
                slot_values[name] = getattr(obj, name)
            except AttributeError:
                # Unassigned slots are not part of the state
                continue

        instance_dict = getattr(obj, "__dict__", None)
        if slot_values:
            return instance_dict, slot_values
        return instance_dict

else:
    object__getstate__ = object.__getstate__