Source code for cosapp.core.ir.ops

import ast
import sys
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Optional


if sys.version_info >= (3, 10):
    from dataclasses import KW_ONLY

    @dataclass()
    class Node:
        """Base class for AST nodes.

        ``desc`` is declared after the ``KW_ONLY`` marker, so it can only be
        passed by keyword and defaults to ``None``.
        """

        _fields = ["desc"]
        _: KW_ONLY
        desc: Optional[str] = None

else:

    class KW_ONLY:
        """Stand-in for ``dataclasses.KW_ONLY``, which needs Python 3.10+."""

    class MetaC(type):
        """Metaclass rewriting class bodies that use the ``_: KW_ONLY`` marker.

        The ``_`` pseudo-field is dropped from every class. ``desc`` is
        removed from ``Node`` itself and re-declared, with a ``None``
        default, on every other class built with this metaclass.
        """

        def __new__(cls, name, bases, dct):
            # A class body without any annotation has no ``__annotations__``
            # entry; create it instead of failing with ``KeyError``.
            annotations = dct.setdefault("__annotations__", {})
            annotations.pop("_", None)

            if name == "Node":
                # Tolerate a ``Node`` body that does not declare ``desc``.
                dct.pop("desc", None)
                annotations.pop("desc", None)
            else:
                dct["desc"] = None
                annotations["desc"] = Optional[str]

            namespace = {key: val for key, val in dct.items() if key != "_"}
            return super().__new__(cls, name, bases, namespace)

    @dataclass()
    class Node(metaclass=MetaC):
        """Base class for AST nodes.

        On this code path ``MetaC`` strips ``desc`` from ``Node`` itself and
        adds it to each subclass instead.
        """

        _fields = ["desc"]
        _: KW_ONLY
        desc: Optional[str] = None
@dataclass
class Variable(Node):
    """AST node describing a variable.

    ``name`` is the only field declared before the ``KW_ONLY`` marker; all
    other fields are optional and default to ``None``, except ``type_value``
    which defaults to ``float``.
    """

    _fields = ["name", "unit", "initial_value", "min_value", "max_value"]

    name: str
    _: KW_ONLY
    dtype: Optional[type | list[type]] = None
    type_value: type | list[type] = float
    shape: Optional[tuple[int, int]] = None
    unit: Optional[str] = None
    initial_value: Optional[Any] = None
    # Bounds are stored as AST constants rather than plain Python values.
    min_value: Optional[ast.Constant] = None
    max_value: Optional[ast.Constant] = None
@dataclass
class Unknown(Variable):
    """AST node for an unknown: a ``Variable`` with optional step limits."""

    _fields = ["name", "min_value", "max_value", "max_abs_step", "max_rel_step"]

    # Both limits are optional AST constants and default to ``None``.
    max_abs_step: Optional[ast.Constant] = None
    max_rel_step: Optional[ast.Constant] = None
@dataclass
class Connection(Node):
    """AST node for a connection between a ``source`` and a ``sink``.

    All fields follow the ``KW_ONLY`` marker. ``source`` and ``sink`` are
    required strings; ``mapping`` is an optional ``str -> str`` dictionary.
    """

    _fields = [
        "source",
        "sink",
        "mapping",
    ]

    _: KW_ONLY
    source: str
    sink: str
    mapping: Optional[dict[str, str]] = None
@dataclass
class Relationship(Node):
    """AST node tying a ``source`` to a ``sink`` through a callable ``func``.

    All three fields follow the ``KW_ONLY`` marker and have no default.
    """

    _fields = [
        "source",
        "sink",
        "func",
    ]

    _: KW_ONLY
    source: str
    sink: str
    func: Callable
@dataclass()
class Signature(Node):
    """Function signature AST node.

    Both fields follow the ``KW_ONLY`` marker. ``return_type`` is a required
    string; ``args`` is a list of ``Variable`` nodes, empty by default.
    """

    # Entries must be real attribute names: the previous "ret_type" entry did
    # not match the ``return_type`` field declared below.
    _fields = ["return_type", "args"]
    _: KW_ONLY
    return_type: str
    args: list[Variable] = field(default_factory=list)
@dataclass
class FunctionPrototype(Node):
    """AST node for a function prototype: a ``name`` plus a ``Signature``.

    ``name`` is declared before the ``KW_ONLY`` marker; ``signature``
    (required) and ``long_desc`` (optional, ``None`` by default) come after.
    """

    _fields = [
        "name",
        "signature",
    ]

    name: str
    _: KW_ONLY
    signature: Signature
    long_desc: Optional[str] = None
@dataclass
class FunctionDef(FunctionPrototype):
    """AST node for a function definition.

    Extends ``FunctionPrototype`` with a list of ``Variable`` nodes and a list
    of ``ast.stmt`` statements; both lists are empty by default.
    """

    _fields = [
        "name",
        "signature",
        "desc",
        "body",
        "variables",
    ]

    variables: list[Variable] = field(default_factory=list)
    body: list[ast.stmt] = field(default_factory=list)
@dataclass()
class PortDef(Node):
    """AST node for a port definition: a ``name`` and its ``Variable`` nodes.

    ``variables`` follows the ``KW_ONLY`` marker and is empty by default.
    """

    _fields = [
        "name",
        "variables",
    ]

    name: str
    _: KW_ONLY
    variables: list[Variable] = field(default_factory=list)
@dataclass()
class Port(Node):
    """AST node for a port instance.

    ``ty`` is the ``PortDef`` of this port; it follows the ``KW_ONLY`` marker
    and has no default.
    """

    _fields = [
        "name",
        "ty",
    ]

    name: str
    _: KW_ONLY
    ty: PortDef
@dataclass
class Relation(Node):
    """Relation AST node.

    Holds a left- and a right-hand side expression together with a comparison
    operator class (``ast.Eq`` by default). All fields follow the ``KW_ONLY``
    marker; ``equation``, ``lhs`` and ``rhs`` are required.

    Representation, equality and hashing are all derived from the unparsed
    ``lhs`` and ``rhs`` joined by ``" == "``.
    """

    _fields = ["lhs", "rhs", "operator"]
    _: KW_ONLY
    equation: str
    lhs: ast.expr
    rhs: ast.expr
    operator: ast.Eq | ast.Lt | ast.LtE | ast.Gt | ast.GtE = ast.Eq
    is_transient: bool = False

    def __iter__(self) -> Iterator:
        """Iterate over a relation as a tuple: ``lhs`` first, then ``rhs``."""
        yield self.lhs
        yield self.rhs

    def _unparsed(self) -> str:
        """Return ``"<lhs> == <rhs>"`` built from the unparsed expressions."""
        # FIXME: handle all operators
        return " == ".join([ast.unparse(self.lhs), ast.unparse(self.rhs)])

    def __repr__(self) -> str:
        return self._unparsed()

    def __eq__(self, other) -> bool:
        """Compare with another object.

        Two relations are equal when their unparsed texts are equal. Any
        other object is compared by hash, as before; unhashable objects
        yield ``NotImplemented`` instead of raising ``TypeError``.
        """
        if isinstance(other, Relation):
            # Compare the texts directly rather than their hashes, so that a
            # hash collision cannot make two different relations equal.
            return self._unparsed() == other._unparsed()
        try:
            # Kept for backward compatibility: an object hashing like the
            # relation text (notably that very string) still compares equal.
            return hash(self) == hash(other)
        except TypeError:
            # `other` is unhashable (e.g. a list or a non-frozen dataclass).
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self._unparsed())
@dataclass
class NestedModel(Node):
    """NestedModel instance AST node.

    ``ty`` is the model this instance refers to and ``pullings`` a
    ``str -> str`` dictionary, empty by default; both follow the ``KW_ONLY``
    marker. Attribute names that are not found through normal lookup are
    searched among the outputs and inputs of ``ty.topology``.
    """

    _fields = ["name", "ty", "pullings"]
    name: str
    _: KW_ONLY
    ty: "AbstractModel"
    pullings: dict[str, str] = field(default_factory=dict)

    def __getattr__(self, name: str) -> tuple[Port | Variable, str]:
        """Resolve `name` among the outputs, then the inputs, of ``ty.topology``.

        Returns:
            The first element whose ``name`` matches, paired with this
            instance's own ``name``.

        Raises:
            AttributeError: if no element matches, or if ``ty`` is not set yet.
        """
        message = f"'NestedModel' object has no attribute '{name}'"
        try:
            # Bypass __getattr__: on an instance whose fields are not set yet
            # (e.g. the bare object that copy/pickle probe for __setstate__),
            # `self.ty` would re-enter this method and recurse without bound.
            ty = object.__getattribute__(self, "ty")
        except AttributeError:
            raise AttributeError(message) from None

        elements = ty.topology.outputs + ty.topology.inputs
        for elt in elements:
            if name == elt.name:
                return elt, self.name
        raise AttributeError(message)

    @property
    def outwards(self) -> tuple[list[Variable | Port], str]:
        """Outputs of ``ty.topology``, paired with this instance's name."""
        return self.ty.topology.outputs, self.name

    @property
    def inwards(self) -> tuple[list[Variable | Port], str]:
        """Inputs of ``ty.topology``, paired with this instance's name."""
        return self.ty.topology.inputs, self.name
@dataclass()
class MathematicalProblem(Node):
    """AST node grouping ``Relation`` constraints and ``Unknown`` nodes.

    Both lists follow the ``KW_ONLY`` marker and are empty by default.
    """

    _fields = [
        "constraints",
        "unknowns",
    ]

    _: KW_ONLY
    constraints: list[Relation] = field(default_factory=list)
    unknowns: list[Unknown] = field(default_factory=list)
@dataclass()
class ModelTopology(Node):
    """AST node describing the topology of a model.

    Every field has an empty default: a dictionary for ``children``, a fresh
    ``MathematicalProblem`` for ``mathematical_problem`` and a list for the
    others.
    """

    _fields = [
        "children",
        "connections",
        "inputs",
        "outputs",
        "relationships",
        "mathematical_problem",
    ]

    children: dict[str, NestedModel] = field(default_factory=dict)
    connections: list[Connection] = field(default_factory=list)
    inputs: list[Port | Variable] = field(default_factory=list)
    outputs: list[Port | Variable] = field(default_factory=list)
    mathematical_problem: MathematicalProblem = field(default_factory=MathematicalProblem)
    relationships: list[Relationship] = field(default_factory=list)
@dataclass()
class AbstractModel(Node):
    """AST node for an abstract model: a name, its bases and a topology.

    ``bases`` defaults to an empty list and ``topology`` to a fresh
    ``ModelTopology``.
    """

    name: str
    bases: list["AbstractModel"] = field(default_factory=list)
    topology: ModelTopology = field(default_factory=ModelTopology)

    @property
    def outwards(self) -> tuple[list[Port | Variable], str]:
        """Outputs of the topology, paired with the model name."""
        topology = self.topology
        return topology.outputs, self.name

    @property
    def inwards(self) -> tuple[list[Port | Variable], str]:
        """Inputs of the topology, paired with the model name."""
        topology = self.topology
        return topology.inputs, self.name
@dataclass
class ComputeFunction(Node):
    """Compute function of a causal model AST node.

    All fields follow the ``KW_ONLY`` marker and are empty by default:
    ``local_names`` and ``global_names`` are lists of strings, ``globals`` a
    dictionary keyed by string and ``body`` a list of ``ast.stmt``.
    """

    # Entries must be real attribute names: the previous "locals" entry did
    # not match any field; the declared field is ``local_names``.
    _fields = ["desc", "local_names", "body"]
    _: KW_ONLY
    local_names: list[str] = field(default_factory=list)
    global_names: list[str] = field(default_factory=list)
    globals: dict[str, Any] = field(default_factory=dict)
    body: list[ast.stmt] = field(default_factory=list)
@dataclass()
class CausalModel(AbstractModel):
    """AST node for a causal model.

    Extends ``AbstractModel`` with a ``ComputeFunction`` and an ``exec_order``
    list of strings; both follow the ``KW_ONLY`` marker and are empty by
    default.
    """

    _fields = [
        "name",
        "desc",
        "topology",
        "compute",
    ]

    _: KW_ONLY
    compute: ComputeFunction = field(default_factory=ComputeFunction)
    exec_order: list[str] = field(default_factory=list)
# Alias: ``Model`` is another name for ``CausalModel``.
Model = CausalModel
@dataclass()
class Module(Node):
    """AST node for a module.

    ``body`` follows the ``KW_ONLY`` marker; it is a list mixing ``Model``
    nodes and ``ast.stmt`` statements, empty by default.
    """

    _fields = [
        "body",
    ]

    _: KW_ONLY
    body: list[Model | ast.stmt] = field(default_factory=list)