Source code for cosapp.core.ir.passes.compute_optimizations

import ast
import logging

from cosapp.core.ir.visitors import NodeTransformer, NodeVisitor

from .. import ops


class _InlineConstants(NodeTransformer):
    """Evaluate constant in inline operations."""

    def __init__(self, compute_globals: dict, compute_locals: dict) -> None:
        self._vars: dict = dict(compute_locals)
        self._vars.update(compute_globals)

    def __call__(self, node: ast.AST) -> ast.AST:
        self.visit(node)
        return node

    def visit_Name(self, node: ast.Name) -> ast.Constant | ast.Name:
        if node.id in self._vars and isinstance(self._vars[node.id], int):
            return ast.Constant(value=self._vars[node.id])

        return node


class _BinOpsCanonicalization(NodeTransformer):
    """Canonicalize chains of additions and subtractions.

    A chain such as ``1 + x - 2 + y`` is flattened into signed terms and its
    numeric (``int``/``float``) constants are summed. It is then rebuilt as
    ``x + y - 1``: non-constant terms first, in their original order,
    followed by a single constant term.

    Zero constants are dropped when the sum is an ``int``. A float sum is
    always kept (``x + 0.0``) so that int/float promotion is preserved.
    Note that the chain is reassociated, so float results may differ in the
    last bits from the original evaluation order.
    """

    def __call__(self, node: ast.AST) -> ast.AST:
        """Transform ``node`` in place and return the (possibly replaced) root."""
        new_node = self.visit(node)
        return node if new_node is None else new_node

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        # Canonicalize sub-expressions first (post-order traversal).
        self.generic_visit(node)

        if not isinstance(node.op, (ast.Add, ast.Sub)):
            return node

        const_sum = 0
        terms: list[tuple[ast.AST, int]] = []
        for expr, sign in self._flatten(node):
            if isinstance(expr, ast.Constant) and isinstance(expr.value, (int, float)):
                const_sum += sign * expr.value
            else:
                terms.append((expr, sign))

        if const_sum < 0:
            terms.append((ast.Constant(-const_sum), -1))
        elif const_sum != 0 or isinstance(const_sum, float):
            # ``!= 0`` also keeps NaN (neither < 0 nor > 0); a float zero is
            # kept so that the expression still evaluates to a float.
            terms.append((ast.Constant(const_sum), +1))

        if not terms:
            # Integer constants cancelled out and nothing else remains.
            return ast.copy_location(ast.Constant(const_sum), node)

        first_expr, first_sign = terms[0]
        if first_sign == -1:
            result = ast.UnaryOp(op=ast.USub(), operand=first_expr)
        else:
            result = first_expr

        for expr, sign in terms[1:]:
            op = ast.Add() if sign == +1 else ast.Sub()
            result = ast.BinOp(left=result, op=op, right=expr)

        # Newly built nodes inherit the location of the chain they replace.
        if getattr(result, "lineno", None) is None:
            ast.copy_location(result, node)
        return ast.fix_missing_locations(result)

    def _flatten(self, node: ast.AST, sign: int = 1) -> list[tuple[ast.AST, int]]:
        """Return the signed terms (``sign`` is +1 or -1) of an add/sub chain."""
        if isinstance(node, ast.BinOp):
            if isinstance(node.op, ast.Add):
                return self._flatten(node.left, sign) + self._flatten(node.right, sign)

            if isinstance(node.op, ast.Sub):
                return self._flatten(node.left, sign) + self._flatten(node.right, -sign)

        return [(node, sign)]


class _ConstantFolding(NodeTransformer):
    """Evaluate constant folding."""

    def __call__(self, node: ast.AST) -> ast.AST:
        self.visit(node)
        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp | ast.Constant:
        left = self.visit(node.left)
        right = self.visit(node.right)

        if isinstance(left, ast.Constant) and isinstance(right, ast.Constant):
            return ast.fix_missing_locations(
                ast.Constant(value=eval(compile(ast.Expression(node), "ast", "eval")))
            )

        return node


class _ComputeLocals(NodeVisitor):
    """Collect and evaluate the values of the local variables of a compute body.

    Each ``Assign`` statement whose targets are plain names is executed on
    its own. The compute globals serve as the global namespace and the values
    gathered so far as the local namespace. Statements that cannot be
    evaluated this way (e.g. because they read run-time inputs) are skipped.

    A local name is reported with its value only when all of these hold:

    - it is bound exactly once;
    - that binding is a successfully evaluated assignment;
    - the assignment's right-hand side reads no other non-constant local.

    Every other local name is mapped to the ``NOT_CONSTANT`` sentinel. Such
    a name still shadows a global variable of the same name, but is never
    treated as a known constant.
    """

    # Sentinel for local names whose value is not a compile-time constant.
    NOT_CONSTANT = object()

    def __init__(self) -> None:
        self._compute_local_values: dict = {}
        self._compute_global_values: dict = {}
        # Number of times each name is bound (stored or deleted) in the body.
        self._binding_counts: dict[str, int] = {}
        # Names read by the assignment binding each evaluated name.
        self._dependencies: dict[str, set[str]] = {}

    def __call__(self, node: ast.AST) -> dict:
        """Return a ``{name: value}`` mapping of the locals of ``node``.

        ``node`` must expose ``local_names`` and ``globals`` attributes.
        Nothing is evaluated (an empty dict is returned) when the only local
        name is ``"self"``.
        """
        self._compute_local_values = {}
        self._binding_counts = {}
        self._dependencies = {}

        if node.local_names == ("self",):
            return {}

        self._compute_global_values = node.globals
        self.visit(node)
        return self._resolve(node.local_names)

    def visit_Name(self, node: ast.Name) -> None:
        if isinstance(getattr(node, "ctx", None), (ast.Store, ast.Del)):
            self._binding_counts[node.id] = self._binding_counts.get(node.id, 0) + 1

    def visit_Assign(self, node: ast.Assign) -> None:
        # Count the names bound here (targets and walrus operators).
        self.generic_visit(node)

        if not all(self._is_name_target(target) for target in node.targets):
            # Attribute/subscript targets would mutate live objects at
            # optimization time; they bind no local name anyway.
            return

        reads = {
            child.id
            for child in ast.walk(node.value)
            if isinstance(child, ast.Name)
            and not isinstance(getattr(child, "ctx", None), ast.Store)
        }
        for child in ast.walk(node):
            if isinstance(child, ast.Name) and isinstance(getattr(child, "ctx", None), ast.Store):
                self._dependencies[child.id] = reads

        try:
            module = ast.Module(body=[node], type_ignores=[])
            code = compile(ast.fix_missing_locations(module), "<ast>", "exec")
            exec(code, self._compute_global_values, self._compute_local_values)
        except Exception as exc:
            # Best effort: evaluating arbitrary user code may fail for many
            # legitimate reasons (e.g. dependency on run-time inputs).
            logging.getLogger(__name__).debug(
                "Could not evaluate assignment at line %s: %r",
                getattr(node, "lineno", "?"),
                exc,
            )

    @staticmethod
    def _is_name_target(target: ast.AST) -> bool:
        """Tell whether ``target`` only binds plain names (``a``, ``a, *b``, ``[a, b]``)."""
        if isinstance(target, ast.Name):
            return True
        if isinstance(target, ast.Starred):
            return _ComputeLocals._is_name_target(target.value)
        if isinstance(target, (ast.Tuple, ast.List)):
            return all(_ComputeLocals._is_name_target(elt) for elt in target.elts)
        return False

    def _resolve(self, local_names) -> dict:
        """Build the final mapping, replacing unreliable values by ``NOT_CONSTANT``."""
        values = self._compute_local_values
        local = set(self._binding_counts).union(local_names)
        constant = {
            name
            for name, count in self._binding_counts.items()
            if count == 1 and name in values
        }
        # A value computed from a non-constant local (loop variable, rebound
        # name, argument, ...) depends on control flow: discard it, transitively.
        while True:
            unreliable = {
                name
                for name in constant
                if any(
                    dep in local and dep not in constant
                    for dep in self._dependencies.get(name, ())
                )
            }
            if not unreliable:
                break
            constant -= unreliable

        return {
            name: values[name] if name in constant else self.NOT_CONSTANT
            for name in local
        }


def optimize_compute(node: ops.CausalModel) -> ops.CausalModel:
    """Simplify the constant parts of the compute body of a CoSApp IR model.

    The local variable values of ``node.compute`` are evaluated first. Then
    three passes are applied in sequence, each visiting ``node.compute``:

    1. integer-valued variables are replaced by their literal value;
    2. chains of additions/subtractions are flattened and their numeric
       constants summed;
    3. operations on constant operands are evaluated.

    Parameters
    ----------
    node : ops.CausalModel
        Model whose ``compute`` part is transformed by the passes.

    Returns
    -------
    ops.CausalModel
        The same ``node`` object that was passed in.
    """
    compute = node.compute
    known_values = _ComputeLocals()(compute)

    # Passes are created up front, then run in order on the compute body.
    passes = (
        _InlineConstants(compute.globals, known_values),
        _BinOpsCanonicalization(),
        _ConstantFolding(),
    )
    for transformer in passes:
        transformer.visit(compute)

    return node