Source code for cosapp.recorders.recorder

"""Base class for recording data."""
from __future__ import annotations
import abc
import copy
import pandas
import warnings
from typing import Dict, Union, List, Optional, NamedTuple, Any, TYPE_CHECKING

from cosapp.core.signal import Signal
from cosapp.core.eval_str import EvalString
from cosapp.core.execution import ExecutionType
from cosapp.utils.helpers import is_numerical, check_arg
from cosapp.utils.find_variables import SearchPattern, make_wishlist, find_variables
from cosapp.utils.state_io import object__getstate__
if TYPE_CHECKING:
    from cosapp.core.module import Module


[docs] class SpecialColumns(NamedTuple): section: str status: str code: str reference: str
[docs] def is_cosapp_object(obj: Any) -> bool: """Returns `True` if `obj` is either a port, a system or a driver; `False` otherwise.""" from cosapp.base import Port, System, Driver return isinstance(obj, (Port, System, Driver))
[docs] def any_forbidden_object(collection: Any) -> bool: """Returns `True` if argument is either an iterator, a cosapp object, or a collection thereof; `False` otherwise.""" if isinstance(collection, str): return False try: iterator = iter(collection) except: pass else: if iterator is collection: # collection is an iterator; leave it untouched return True try: return any(map(is_cosapp_object, collection)) except TypeError: return is_cosapp_object(collection)
[docs] def no_forbidden_objects(dict_item: tuple) -> bool: """Filter function for dictionary items, to exclude entries that match `any_forbidden_object`.""" return not any_forbidden_object(dict_item[1])
[docs] class BaseRecorder(abc.ABC): """Abstract base class for recorders. Matching pattern are case sensitive and support the following special patterns: ======== ================================ Pattern Meaning ======== ================================ `*` matches everything `?` matches any single character `[seq]` matches any character in seq `[!seq]` matches any character not in seq ======== ================================ Excluding pattern are shadowing includes one; e.g. if `includes='*port_in.*' and `excludes='*.Pt'`, for a port having variables named `Tt` and `Pt`, only `Tt` will be recorded. Parameters ---------- includes : str or list of str, optional Variables matching these patterns will be included; default `'*'` (i.e. all variables). excludes : str or list of str or None, optional Variables matching these patterns will be excluded; default `None` (i.e. nothing is excluded). numerical_only : bool, optional Keep only numerical variables (i.e. number or numerical vector); default False. section : str, optional Current section name; default `''`. precision : int, optional Precision digits when writing floating point number; default 9 (i.e. 10 figures will be written). hold : bool, optional Append the new data or not; default `False`. raw_output : bool, optional Raw output; default `False`. Signals ------- state_recorded : Signal(time_ref: Union[float, str], status: str, error_code: str) Signal emitted after `record_state` execution cleared : Signal() Signal emitted after `clear` execution """ SPECIALS = SpecialColumns( section="Section", status="Status", code="Error code", reference="Reference", ) def __init__(self, includes: SearchPattern = "*", excludes: Optional[SearchPattern] = None, # metadata: Optional[SearchPattern] = None, # TODO? numerical_only = False, section = "", precision = 9, hold = False, raw_output = False, ): check_arg(raw_output, 'raw_output', bool) check_arg(numerical_only, 'numerical_only', bool) self.__includes = make_wishlist(includes, "includes") # type: List[str] self.__excludes = make_wishlist(excludes, "excludes") # type: List[str] self._numerical_only = numerical_only # type: bool # self.__metadata = metadata # type: List[str] self.__section = "" # type: str self.section = section self.__hold = False # type: bool self.hold = hold self.__precision = 9 # type: int self.precision = precision self._raw_output = raw_output # type: bool self.__variables = None # type: Optional[List[str]] self.__expressions = None # type: EvalString self._watch_object = None # type: Optional[cosapp.core.module.Module] self._owner = None # type: Optional[str] self.paused = False # Signal self.state_recorded = Signal( args=["time_ref", "status", "error_code"], name="cosapp.recorders.recorder.BaseRecorder.state_recorded", ) self.cleared = Signal(name="cosapp.recorders.recorder.BaseRecorder.cleared") def __getstate__(self) -> Dict[str, Any]: """Creates a state of the object. The state type depend on the object, see https://docs.python.org/3/library/pickle.html#object.__getstate__ for further details. Returns ------- Dict[str, Any]: state """ state = object__getstate__(self).copy() return state def __json__(self) -> Dict[str, Any]: """Creates a JSONable dictionary representation of the object. Returns ------- Dict[str, Any] The dictionary """ qualname = f"{self.__module__}.{self.__class__.__qualname__}" state = self.__getstate__() state.pop("_watch_object") state.pop("_owner") state.pop("state_recorded") state.pop("cleared") return { "__class__": qualname, **state, }
[docs] @classmethod def from_dict(cls, state): rec = cls() for key, val in state.items(): setattr(rec, key, val) return rec
[docs] @classmethod def extend( cls, recorder, includes: SearchPattern = [], excludes: SearchPattern = [] ) -> BaseRecorder: """ Factory returning a new recorder, with similar attributes as `recorder`, but extended `includes` and `excludes` fields. Parameters ---------- includes : str, list of str, or None, optional Variables patterns extending `recorder` inclusion patterns; default `[]` (i.e. no change). excludes : str, list of str, or None, optional Variables patterns extending `recorder` exclusion patterns; default `[]` (i.e. no change). """ new = cls( recorder.includes + make_wishlist(includes, "includes"), recorder.excludes + make_wishlist(excludes, "excludes"), recorder._numerical_only, recorder.section, recorder.precision, recorder.hold, recorder._raw_output, ) new.watched_object = recorder.watched_object return new
@property def watched_object(self) -> Optional[Module]: """Module : The object from which data are read.""" return self._watch_object @watched_object.setter def watched_object(self, module: Module): from cosapp.core.module import Module if not (module is None or isinstance(module, Module)): typename = type(module).__qualname__ raise TypeError( f"Recorder must be attached to a System or a Driver; got {typename}." ) is_new = module is not self._watch_object if self._watch_object is None or is_new: self.__variables = None self.__expressions = None self._watch_object = module @property @abc.abstractmethod def _raw_data(self) -> List[List[Any]]: """Returns a raw/unformated version of records. Returns ------- List[List[Any]] Records of `watched_object` for variables given by method `field_names()` """ pass @property def data(self) -> pandas.DataFrame: """ pandas.DataFrame: DataFrame containing the results. Deprecated property; use `export_data()` instead. """ warnings.warn( "Deprecated property; use export_data() instead", DeprecationWarning ) return self.export_data()
[docs] @abc.abstractmethod def export_data(self) -> pandas.DataFrame: """Export recorded results into a pandas.DataFrame object.""" pass
@property def includes(self) -> List[str]: """str or list of str : Variables matching these patterns will be included.""" return self.__includes @property def excludes(self) -> List[str]: """str or list of str: Variables matching these patterns will be excluded.""" return self.__excludes @property def section(self) -> str: """str : Current section name.""" return self.__section @section.setter def section(self, section: str): check_arg(section, 'section', str) self.__section = section @property def hold(self) -> bool: """bool : Append the new data or not.""" return self.__hold @hold.setter def hold(self, value: bool): check_arg(value, 'hold', bool) self.__hold = value @property def precision(self) -> int: """int : Precision digits when writing floating point number.""" return self.__precision @precision.setter def precision(self, value: int): check_arg(value, 'precision', int, lambda n: n > 0) self.__precision = value
[docs] def field_names(self) -> List[str]: """Returns list of requested variables and expressions. Returns ------- List[str] Variable names matching the includes/excludes patterns of the user in the watched object. .. note:: Inward and outward variables will appear without the prefix `inwards.` or `outwards.`. """ if self.__variables is None and self.watched_object is not None: self.__update_varlist() return self.__variables or list()
def __contains__(self, field: str) -> bool: return field in self.field_names() def __update_varlist(self) -> None: """Update the list of fields to be recorded""" includes = [] evaluables = {} context = self.watched_object for expression in set(self.__includes): try: evaluable = EvalString(expression, context) evaluable.eval() except: includes.append(expression) else: if expression in context: # Keep expression in search list, # in case it matches an exclusion pattern includes.append(expression) else: evaluables[expression] = evaluable.eval() if self._numerical_only: criterion = is_numerical else: criterion = lambda x: True variables = find_variables( context, includes, self.__excludes, advanced_filter=criterion, include_const=True, ) variables.update(evaluables) # Filter out attributes containing forbidden objects # such as non-copyable CoSApp objects or iterators variables = dict( filter(no_forbidden_objects, variables.items()) ) self.__variables = variables = sorted(variables) self.__expressions = EvalString(", ".join(variables).join("[]"), context) def _get_units(self, varnames: Optional[List[str]] = None) -> List[str]: """Generate the list of units associated with the requested variables. Parameters ---------- vars : list of str, optional List of the variables to get the units; if not provided, takes field_names Returns ------- List[str] Variable units for the variables matching the includes/excludes patterns of the user. """ if varnames is None: varnames = self.field_names() units = list() if self.watched_object is not None: name2variable = self.watched_object.name2variable for name in varnames: try: ref = name2variable[name] details = ref.mapping.get_details(ref.key) unit = details.unit except: unit = None units.append(unit or "-") return units
[docs] def start(self): """Initialize recording support.""" if self.watched_object is None: raise RuntimeError("A recorder should be watching a Driver.")
[docs] def record_state(self, time_ref: Union[float, str], status="", error_code="0") -> None: """Record the watched object at the provided status. Parameters ---------- time_ref : float or str Current simulation time (float) or point reference (str) status : str, optional Status of the simulation; default `''`. error_code : str, optional Error code; default `'0'`. """ if self.paused: return line = [self.section, status, error_code, str(time_ref)] + self.formatted_data() self._record(line) self.state_recorded.emit(time_ref=time_ref, status=status, error_code=error_code)
[docs] @abc.abstractmethod def formatted_data(self) -> List[Any]: """Returns collected data from watched object as a formatted list.""" pass
[docs] def collected_data(self) -> List[Any]: """Collects and returns recorded data from watched object as a list.""" if self.__expressions is None: self.__update_varlist() return self.__expressions.eval()
@abc.abstractmethod def _record(self, line: List[Any]) -> None: """ Internally record data collected into list `line`. Required by `record_state` method. """ pass @abc.abstractmethod def _batch_record(self, lines: List[List[Any]]) -> None: """Records multiple lines at a time. Internal API allowing efficient concatenation of recorders. """ pass @abc.abstractmethod def _enable_parallel_execution(self, exec_type: ExecutionType, chunk_id: int) -> None: """Enables the use of this `Recorder` in parallel execution. This method must perform the necessary changes to allow parallel execution in a multithreading or multiprocessing context. Parameters ---------- exec_type : ExecutionType Type of parallel execution chunk_id : int Identifier of the chunk to be handled by this recorder """ pass @abc.abstractmethod def _disable_parallel_execution(self, exec_type: ExecutionType, chunk_id: int) -> None: """Disables the use of this `Recorder` in parallel execution. This method rollbacks the changes made to the `Recorder` to handle parallel execution. Parameters ---------- exec_type : ExecutionType Type of parallel execution chunk_id : int Identifier of the chunk to be handled by this recorder """ pass def _set_parallel_execution(self, chunck_id: int) -> None: """"""
[docs] def clear(self): """Clear all previously stored data.""" self.cleared.emit()
[docs] @abc.abstractmethod def exit(self): """Close recording session.""" pass
[docs] def restore(self, index: int): """Restore the watch object state from the recorded data. Parameters ---------- index : int Index of the record as iloc in the Pandas DataFrame """ check_arg(index, 'index', int) data = self._raw_data if index < 0 or index > len(data) - 1: raise IndexError(f"Index {index} does not exist in this Recorder") inputs = find_variables( self._watch_object, includes=["*"], excludes=[], outputs=False ) recording = data[index][4:] for idx, var in [ (idx, var) for idx, var in enumerate(self.field_names()) if var in inputs ]: try: self._watch_object[var] = copy.deepcopy(recording[idx]) except KeyError: pass
[docs] def check(self) -> None: try: self.formatted_data() except Exception as ex: warnings.warn( f"Captured exception {ex!r} while trying to collect data in {self.watched_object}" "; recorder may fail.", RuntimeWarning )