import ast
import sys
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Optional
if sys.version_info >= (3, 10):
from dataclasses import KW_ONLY
@dataclass()
class Node:
"""Base class for AST nodes."""
_fields = ["desc"]
_: KW_ONLY
desc: Optional[str] = None
else:
class KW_ONLY:
pass
class MetaC(type):
def __new__(cls, name, bases, dct):
dct["__annotations__"].pop("_", None)
if name == "Node":
dct.pop("desc")
dct["__annotations__"].pop("desc", None)
else:
dct["desc"] = None
dct["__annotations__"]["desc"] = Optional[str]
return super().__new__(
cls, name, bases, {key: val for key, val in dct.items() if key != "_"}
)
[docs]
@dataclass()
class Node(metaclass=MetaC):
"""Base class for AST nodes."""
_fields = ["desc"]
_: KW_ONLY
desc: Optional[str] = None
[docs]
@dataclass()
class Variable(Node):
    """Variable AST node.

    ``name`` is declared before the ``KW_ONLY`` marker; every other field is
    declared after it and has a default (``None``, except ``type_value``
    which defaults to ``float``).
    """

    # Attribute names listed as this node's fields. ``dtype``, ``type_value``
    # and ``shape`` are not listed.
    _fields = [
        "name",
        "unit",
        "initial_value",
        "min_value",
        "max_value",
    ]
    name: str
    _: KW_ONLY
    dtype: Optional[type | list[type]] = None
    type_value: type | list[type] = float
    shape: Optional[tuple[int, int]] = None
    unit: Optional[str] = None
    initial_value: Optional[Any] = None
    # Bounds are typed as ``ast.Constant`` nodes, not raw Python values.
    min_value: Optional[ast.Constant] = None
    max_value: Optional[ast.Constant] = None
[docs]
@dataclass()
class Unknown(Variable):
    """Unknown AST node.

    A ``Variable`` with two extra optional fields, ``max_abs_step`` and
    ``max_rel_step``, both typed as ``ast.Constant`` nodes.
    """

    # Attribute names listed as this node's fields; ``unit`` and
    # ``initial_value`` inherited from ``Variable`` are not listed here.
    _fields = [
        "name",
        "min_value",
        "max_value",
        "max_abs_step",
        "max_rel_step",
    ]
    # Presumably the largest absolute / relative step allowed for this
    # unknown — TODO confirm against the code that consumes it.
    max_abs_step: Optional[ast.Constant] = None
    max_rel_step: Optional[ast.Constant] = None
[docs]
@dataclass()
class Connection(Node):
    """Connection AST node.

    ``source`` and ``sink`` are required keyword-only strings; ``mapping``
    is an optional ``str -> str`` dict (presumably pairing names on the
    source side with names on the sink side — confirm with callers).
    """

    _fields = ["source", "sink", "mapping"]
    _: KW_ONLY
    source: str
    sink: str
    mapping: Optional[dict[str, str]] = None
[docs]
@dataclass()
class Relationship(Node):
    """Relationship AST node.

    All three fields are required and keyword-only: the ``source`` and
    ``sink`` names and a callable ``func`` (how ``func`` is applied is
    decided by the code consuming this node).
    """

    _fields = ["source", "sink", "func"]
    _: KW_ONLY
    source: str
    sink: str
    func: Callable
[docs]
@dataclass()
class Signature(Node):
    """Function signature AST node.

    Holds the declared return type (as a string) and the list of argument
    ``Variable`` nodes. Both fields are keyword-only; ``return_type`` is
    required and ``args`` defaults to an empty list.
    """

    # Each entry must name a real attribute: a stale name (this list used to
    # say "ret_type") is invisible to any ``getattr``-based field walk.
    _fields = ["return_type", "args"]
    _: KW_ONLY
    return_type: str
    args: list[Variable] = field(default_factory=list)
[docs]
@dataclass()
class FunctionPrototype(Node):
    """Function prototype AST node.

    ``name`` is positional; ``signature`` is a required keyword-only
    ``Signature`` and ``long_desc`` an optional keyword-only string.
    """

    _fields = ["name", "signature"]
    name: str
    _: KW_ONLY
    signature: Signature
    long_desc: Optional[str] = None
[docs]
@dataclass()
class FunctionDef(FunctionPrototype):
    """Function definition AST node.

    Extends ``FunctionPrototype`` with the function's ``variables`` and the
    ``body`` statements (both default to empty lists).
    """

    # Unlike most nodes in this module, ``desc`` is listed explicitly here.
    _fields = ["name", "signature", "desc", "body", "variables"]
    # This class body has no ``KW_ONLY`` marker, so dataclasses does not make
    # the two fields below keyword-only.
    variables: list[Variable] = field(default_factory=list)
    body: list[ast.stmt] = field(default_factory=list)
[docs]
@dataclass
class PortDef(Node):
    """Port definition AST node.

    A named port type holding the list of ``Variable`` nodes it declares
    (keyword-only, defaults to an empty list).
    """

    _fields = ["name", "variables"]
    name: str
    _: KW_ONLY
    variables: list[Variable] = field(default_factory=list)
[docs]
@dataclass
class Port(Node):
    """Port instance AST node.

    A named port whose type ``ty`` is a ``PortDef`` (required, keyword-only).
    """

    _fields = ["name", "ty"]
    name: str
    _: KW_ONLY
    ty: PortDef
[docs]
@dataclass
class Relation(Node):
    """Relation AST node: a comparison ``lhs <operator> rhs``.

    ``repr``, equality and hashing all use the canonical text
    ``"<lhs> <op> <rhs>"`` built with ``ast.unparse``; ``equation``,
    ``desc`` and ``is_transient`` do not take part in them. For ``==``
    relations this text is identical to the historical one.
    """

    _fields = ["lhs", "rhs", "operator"]
    _: KW_ONLY
    equation: str
    lhs: ast.expr
    rhs: ast.expr
    # The default is the ``ast.Eq`` class itself; instances such as
    # ``ast.Lt()`` are accepted as well.
    operator: ast.Eq | ast.Lt | ast.LtE | ast.Gt | ast.GtE = ast.Eq
    is_transient: bool = False

    # Source symbol of each comparison operator class (not a dataclass
    # field: it has no annotation).
    _OPERATOR_SYMBOLS = {
        ast.Eq: "==",
        ast.NotEq: "!=",
        ast.Lt: "<",
        ast.LtE: "<=",
        ast.Gt: ">",
        ast.GtE: ">=",
    }

    def _operator_symbol(self) -> str:
        """Return the source symbol of ``operator`` (class or instance)."""
        op = self.operator
        op_type = op if isinstance(op, type) else type(op)
        # Unknown operators keep the legacy "==" rendering.
        return self._OPERATOR_SYMBOLS.get(op_type, "==")

    def _key(self) -> str:
        """Return the canonical text used by ``repr``, ``==`` and ``hash``."""
        return " ".join(
            [ast.unparse(self.lhs), self._operator_symbol(), ast.unparse(self.rhs)]
        )

    def __iter__(self) -> Iterator:
        """Iterate over a relation as the pair ``(lhs, rhs)``."""
        yield self.lhs
        yield self.rhs

    def __repr__(self) -> str:
        return self._key()

    def __eq__(self, other) -> bool:
        # Compare the canonical texts themselves: comparing their hashes
        # could report false equality on collisions and raised TypeError
        # for unhashable operands.
        if not isinstance(other, Relation):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())
[docs]
@dataclass
class NestedModel(Node):
    """NestedModel instance AST node.

    A child model instance called ``name`` whose type is ``ty``. Any port or
    variable of ``ty.topology`` (outputs searched before inputs) can be read
    as an attribute, yielding ``(element, self.name)``.
    """

    _fields = ["name", "ty", "pullings"]
    name: str
    _: KW_ONLY
    ty: "AbstractModel"
    pullings: dict[str, str] = field(default_factory=dict)

    def __getattr__(self, name: str) -> tuple[Port | Variable, str]:
        """Resolve ``name`` among the outputs and inputs of ``ty.topology``.

        Returns:
            ``(element, self.name)`` for the first element whose ``name``
            matches.

        Raises:
            AttributeError: if no element matches or ``ty`` is not set.
        """
        # __getattr__ only runs after normal lookup has failed. Read ``ty``
        # from the instance dict rather than via ``self.ty``: on an instance
        # whose fields are not populated yet (e.g. created by copy/pickle
        # before its state is restored) ``self.ty`` would re-enter
        # __getattr__ and recurse until RecursionError.
        ty = self.__dict__.get("ty")
        if ty is not None:
            for elt in ty.topology.outputs + ty.topology.inputs:
                if name == elt.name:
                    return elt, self.name
        raise AttributeError(f"'NestedModel' object has no attribute '{name}'")

    @property
    def outwards(self) -> tuple[list[Variable | Port], str]:
        """Outputs of the child topology, paired with this instance's name."""
        return self.ty.topology.outputs, self.name

    @property
    def inwards(self) -> tuple[list[Variable | Port], str]:
        """Inputs of the child topology, paired with this instance's name."""
        return self.ty.topology.inputs, self.name
[docs]
@dataclass
class MathematicalProblem(Node):
    """Mathematical problem AST node.

    A list of ``Relation`` constraints and a list of ``Unknown`` nodes; both
    are keyword-only and default to empty lists.
    """

    _fields = ["constraints", "unknowns"]
    _: KW_ONLY
    constraints: list[Relation] = field(default_factory=list)
    unknowns: list[Unknown] = field(default_factory=list)
[docs]
@dataclass
class ModelTopology(Node):
    """Model topology AST node.

    Groups a model's children (by name), connections, inputs, outputs,
    mathematical problem and relationships. Every field has a default (an
    empty container, or an empty ``MathematicalProblem``).
    """

    _fields = ["children", "connections", "inputs", "outputs", "relationships", "mathematical_problem"]
    # No ``KW_ONLY`` marker in this class body: the fields below keep their
    # declaration order in the generated ``__init__``.
    children: dict[str, NestedModel] = field(default_factory=dict)
    connections: list[Connection] = field(default_factory=list)
    inputs: list[Port | Variable] = field(default_factory=list)
    outputs: list[Port | Variable] = field(default_factory=list)
    mathematical_problem: MathematicalProblem = field(
        default_factory=MathematicalProblem
    )
    relationships: list[Relationship] = field(default_factory=list)
[docs]
@dataclass
class AbstractModel(Node):
    """Abstract model AST node.

    A named model with a list of base models and a ``ModelTopology``. It does
    not override ``_fields``, so it inherits ``Node``'s.
    """

    name: str
    bases: list["AbstractModel"] = field(default_factory=list)
    topology: ModelTopology = field(default_factory=ModelTopology)

    @property
    def outwards(self) -> tuple[list[Port | Variable], str]:
        """The topology's outputs, tagged with this model's name."""
        produced = self.topology.outputs
        return produced, self.name

    @property
    def inwards(self) -> tuple[list[Port | Variable], str]:
        """The topology's inputs, tagged with this model's name."""
        consumed = self.topology.inputs
        return consumed, self.name
[docs]
@dataclass
class ComputeFunction(Node):
    """Compute function of a causal model AST node.

    Holds the function ``body`` statements together with lists of local and
    global names and a ``globals`` dict; all fields are keyword-only and
    default to empty containers.
    """

    # Each entry must name a real attribute: the field is ``local_names``;
    # there is no ``locals`` attribute for ``getattr`` to find.
    _fields = ["desc", "local_names", "body"]
    _: KW_ONLY
    local_names: list[str] = field(default_factory=list)
    global_names: list[str] = field(default_factory=list)
    globals: dict[str, Any] = field(default_factory=dict)
    body: list[ast.stmt] = field(default_factory=list)
[docs]
@dataclass
class CausalModel(AbstractModel):
    """Causal model AST node.

    An ``AbstractModel`` with a ``ComputeFunction`` and an ``exec_order``
    list of names (both keyword-only, with empty defaults).
    """

    _fields = ["name", "desc", "topology", "compute"]
    _: KW_ONLY
    compute: ComputeFunction = field(default_factory=ComputeFunction)
    exec_order: list[str] = field(default_factory=list)


# Module-level alias: ``Model`` is the same class as ``CausalModel``.
Model = CausalModel
[docs]
@dataclass
class Module(Node):
    """Module AST node.

    ``body`` is a keyword-only list whose items are either ``Model`` nodes
    or plain ``ast.stmt`` statements (defaults to an empty list).
    """

    _fields = ["body"]
    _: KW_ONLY
    body: list[Model | ast.stmt] = field(default_factory=list)