from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from pymoose.computation import operations as ops
from pymoose.computation import placements as plc
@dataclass
class Computation:
    """A graph of named operations together with the placements they use.

    Operations and placements are each stored in a dict keyed by their
    ``name`` attribute. Edges between operations are implicit: every
    operation's ``inputs`` mapping holds the *names* of the operations it
    consumes, and the lookup helpers below resolve those names through
    ``self.operations``.
    """

    # Operation name -> operation object.
    operations: Dict[str, ops.Operation] = field(default_factory=dict)
    # Placement name -> placement object.
    placements: Dict[str, plc.Placement] = field(default_factory=dict)

    def find_destinations(self, op: ops.Operation) -> list:
        """Return every operation that lists ``op`` among its inputs.

        An operation appears at most once, even if it references ``op``
        through several input slots. Results follow the iteration order of
        ``self.operations``. This is a linear scan over all operations.
        """
        return [
            candidate
            for candidate in self.operations.values()
            if op.name in candidate.inputs.values()
        ]

    def find_sources(self, op: ops.Operation) -> list:
        """Return the operations that feed into ``op``.

        One entry is produced per input slot of ``op``, in the iteration order
        of ``op.inputs``. An operation wired into two slots is therefore
        listed twice.

        Raises:
            KeyError: if an input name is not an operation in this computation.
        """
        return [self.operation(input_name) for input_name in op.inputs.values()]

    def add(self, component):
        """Add an operation or a placement, dispatching on its type.

        Returns:
            The component, as returned by ``add_operation`` or ``add_placement``.

        Raises:
            NotImplementedError: if ``component`` is neither an
                ``ops.Operation`` nor a ``plc.Placement``.
            AssertionError: propagated from the type-specific ``add_*`` method,
                e.g. when the name is already taken.
        """
        if isinstance(component, ops.Operation):
            return self.add_operation(component)
        if isinstance(component, plc.Placement):
            return self.add_placement(component)
        raise NotImplementedError(f"{component}")

    def maybe_add(self, component):
        """Add an operation or placement unless an equal one is already present.

        Returns:
            The component, as returned by the type-specific ``maybe_add_*``.

        Raises:
            NotImplementedError: if ``component`` is neither an
                ``ops.Operation`` nor a ``plc.Placement``.
            AssertionError: if a *different* component already uses the name.
        """
        if isinstance(component, ops.Operation):
            return self.maybe_add_operation(component)
        if isinstance(component, plc.Placement):
            return self.maybe_add_placement(component)
        raise NotImplementedError(f"{component}")

    def placement(self, name: str) -> plc.Placement:
        """Return the placement registered under ``name`` (KeyError if absent)."""
        return self.placements[name]

    def add_placement(self, placement: plc.Placement) -> plc.Placement:
        """Register a new placement and return it.

        Raises:
            AssertionError: if ``placement`` is not a ``plc.Placement`` or its
                name is already registered.
        """
        assert isinstance(placement, plc.Placement), placement
        assert placement.name not in self.placements, placement.name
        self.placements[placement.name] = placement
        return placement

    def maybe_add_placement(self, placement: plc.Placement) -> plc.Placement:
        """Register ``placement`` unless an equal one already has its name.

        Returns the argument itself in both cases.

        Raises:
            AssertionError: if ``placement`` is not a ``plc.Placement``, or
                if the name is taken by a placement that does not compare equal.
        """
        # Same up-front type check as maybe_add_operation, so both
        # maybe_add_* methods reject foreign objects consistently.
        assert isinstance(placement, plc.Placement), placement
        if placement.name in self.placements:
            assert placement == self.placements[placement.name], placement.name
            return placement
        return self.add_placement(placement)

    def find_operations_of_type(self, op_type) -> list:
        """Return all operations that are instances of ``op_type``."""
        return [op for op in self.operations.values() if isinstance(op, op_type)]

    def operation(self, name: str) -> ops.Operation:
        """Return the operation registered under ``name`` (KeyError if absent)."""
        return self.operations[name]

    def add_operation(self, op: ops.Operation) -> ops.Operation:
        """Register a new operation and return it.

        The operation's placement must already be registered, so placements
        have to be added before the operations that use them.

        Raises:
            AssertionError: if ``op`` is not an ``ops.Operation``, its name is
                already taken, or its ``placement_name`` is unknown.
        """
        assert isinstance(op, ops.Operation), op
        assert op.name not in self.operations, op
        assert op.placement_name in self.placements, op.placement_name
        self.operations[op.name] = op
        return op

    def maybe_add_operation(self, op: ops.Operation) -> ops.Operation:
        """Register ``op`` unless an equal operation already has its name.

        Returns the argument itself in both cases.

        Raises:
            AssertionError: if ``op`` is not an ``ops.Operation``, if the name
                is taken by an operation that does not compare equal, or for
                any reason ``add_operation`` would reject it.
        """
        assert isinstance(op, ops.Operation), op
        if op.name in self.operations:
            assert op == self.operations[op.name], op
            return op
        return self.add_operation(op)

    def add_operations(self, operations) -> None:
        """Register each operation in ``operations`` via ``add_operation``."""
        for op in operations:
            self.add_operation(op)

    def remove_operation(self, name: str) -> None:
        """Delete the operation named ``name`` (KeyError if absent).

        Other operations that still list ``name`` in their inputs are not
        updated, so callers should ``rewire`` them first if needed.
        """
        del self.operations[name]

    def remove_operations(self, names) -> None:
        """Delete every operation whose name is in ``names``."""
        for name in names:
            self.remove_operation(name)

    def rewire(self, old_op: ops.Operation, new_op: ops.Operation) -> None:
        """Point every input that references ``old_op`` at ``new_op`` instead.

        Both operations must already be registered. ``old_op`` itself is not
        removed.

        NOTE(review): every operation is scanned, including ``new_op``. If
        ``new_op`` consumes ``old_op``, it ends up referencing itself.
        Confirm callers never rewire in that situation.
        """
        assert old_op.name in self.operations, old_op
        assert new_op.name in self.operations, new_op
        for op in self.operations.values():
            # Only existing keys are reassigned, so iterating the key view
            # while updating values is safe.
            for arg in op.inputs.keys():
                if op.inputs[arg] == old_op.name:
                    op.inputs[arg] = new_op.name