import ast
from .. import ops
from cosapp.core.ir.visitors import NodeVisitor, NodeTransformer
class _InlineConstants(NodeTransformer):
"""Evaluate constant in inline operations."""
def __init__(self, compute_globals: dict, compute_locals: dict) -> None:
self._vars: dict = dict(compute_locals)
self._vars.update(compute_globals)
def __call__(self, node: ast.AST) -> ast.AST:
self.visit(node)
return node
def visit_Name(self, node: ast.Name) -> ast.Constant | ast.Name:
if node.id in self._vars and isinstance(self._vars[node.id], int):
return ast.Constant(value=self._vars[node.id])
return node
class _BinOpsCanonicalization(NodeTransformer):
"""Evaluate constant operations."""
def __call__(self, node: ast.AST) -> ast.AST:
self.visit(node)
return node
def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp | ast.Constant:
self.generic_visit(node)
if not isinstance(node.op, (ast.Add, ast.Sub)):
return node
items = self._flatten(node)
const_sum = 0
others = []
for expr, sign in items:
if isinstance(expr, ast.Constant) and isinstance(expr.value, (int, float)):
const_sum += sign * expr.value
else:
others.append((expr, sign))
new_items = others[:]
if const_sum > 0:
new_items.append((ast.Constant(const_sum), +1))
elif const_sum < 0:
new_items.append((ast.Constant(-const_sum), -1))
if not new_items:
return ast.Constant(0)
expr, sign = new_items[0]
if sign == -1:
expr = ast.UnaryOp(op=ast.USub(), operand=expr)
node = expr
for expr, sign in new_items[1:]:
if sign == +1:
node = ast.BinOp(left=node, op=ast.Add(), right=expr)
else:
node = ast.BinOp(left=node, op=ast.Sub(), right=expr)
ast.fix_missing_locations(node)
return node
def _flatten(self, node: ast.AST, sign: int = 1) -> list[tuple[ast.AST, int]]:
if isinstance(node, ast.BinOp):
if isinstance(node.op, ast.Add):
return self._flatten(node.left, sign) + self._flatten(node.right, sign)
if isinstance(node.op, ast.Sub):
return self._flatten(node.left, sign) + self._flatten(node.right, -sign)
return [(node, sign)]
class _ConstantFolding(NodeTransformer):
"""Evaluate constant folding."""
def __call__(self, node: ast.AST) -> ast.AST:
self.visit(node)
return node
def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp | ast.Constant:
left = self.visit(node.left)
right = self.visit(node.right)
if isinstance(left, ast.Constant) and isinstance(right, ast.Constant):
return ast.fix_missing_locations(
ast.Constant(value=eval(compile(ast.Expression(node), "ast", "eval")))
)
return node
class _ComputeLocals(NodeVisitor):
    """Collect and evaluate local values.

    Each ``ast.Assign`` reached during traversal is compiled and executed on
    its own. It runs against a copy of the compute globals, and its results
    accumulate in a locals dictionary. Assignments that fail to evaluate are
    skipped.

    The collected values are later substituted for *every* read of a name, so
    only names bound exactly once are returned. A name bound more than once
    (by two assignments, an augmented assignment, a loop/``with`` target, a
    walrus or a ``del``) has no single value and is dropped.
    """

    def __init__(self) -> None:
        self._compute_local_values: dict = {}
        self._compute_global_values: dict = {}
        # Number of binding sites seen for each name during the current call.
        self._binding_counts: dict = {}

    def __call__(self, node: ast.AST) -> dict:
        """Return ``{name: value}`` for the evaluable, singly-bound locals of *node*.

        *node* must expose ``local_names`` and ``globals`` attributes.
        """
        self._compute_local_values = {}
        self._binding_counts = {}
        if node.local_names == ("self",):
            # `self` is the only local name: presumably nothing else is assigned.
            return self._compute_local_values
        # Work on a copy: `exec` inserts `__builtins__` into the globals dict
        # it is given, which would otherwise leak into `node.globals`.
        self._compute_global_values = dict(node.globals)
        self.visit(node)
        for name, count in self._binding_counts.items():
            if count > 1:
                self._compute_local_values.pop(name, None)
        return self._compute_local_values

    def _count_binding(self, name: str) -> None:
        """Record one more binding site for *name*."""
        self._binding_counts[name] = self._binding_counts.get(name, 0) + 1

    def visit_Name(self, node: ast.Name) -> None:
        # Binding sites outside plain assignments (augmented assignment,
        # loop/with targets, ...) are reached here, assuming the base
        # visitor's generic traversal descends into them.
        if isinstance(getattr(node, "ctx", None), (ast.Store, ast.Del)):
            self._count_binding(node.id)

    def visit_Assign(self, node: ast.Assign) -> None:
        import logging

        # visit_Assign does not recurse, so count its own binding targets here.
        for sub in ast.walk(node):
            if isinstance(sub, ast.Name) and isinstance(getattr(sub, "ctx", None), ast.Store):
                self._count_binding(sub.id)
        # NOTE(review): this executes the right-hand side at optimization time,
        # so any call it contains (and its side effects) runs here.
        module = ast.Module(body=[node], type_ignores=[])
        try:
            code = compile(ast.fix_missing_locations(module), "<ast>", "exec")
            exec(code, self._compute_global_values, self._compute_local_values)
        except Exception as error:
            # Best effort: values that cannot be computed here are not inlined.
            logging.getLogger(__name__).debug("Could not evaluate assignment: %s", error)
def optimize_compute(node: ops.CausalModel) -> ops.CausalModel:
    """Collect and evaluate all upstream constant operations
    of the compute part of a CoSApp IR model.

    Three passes run in order on ``node.compute``:

    1. inlining of integer-valued global and local names;
    2. canonicalization of ``+``/``-`` chains (literal terms summed);
    3. constant folding of operations on literals.

    Args:
        node: model whose ``compute`` part is optimized.

    Returns:
        The same ``node``. The passes' return values are discarded, so they
        are expected to modify ``node.compute`` in place.
    """
    compute = node.compute
    # Evaluate the values of local assignments in the compute body.
    local_values = _ComputeLocals()(compute)
    # Initialize the optimization passes.
    inlining = _InlineConstants(compute.globals, local_values)
    canonicalization = _BinOpsCanonicalization()
    folding = _ConstantFolding()
    # Visit and evaluate constant operations.
    inlining.visit(compute)
    canonicalization.visit(compute)
    folding.visit(compute)
    return node