"""Msgpack serialization helpers for pymoose computations (pymoose.computation.utils)."""

import re
from dataclasses import fields

import msgpack
import numpy as np

from pymoose.computation import computation as comp_base
from pymoose.computation import dtypes
from pymoose.computation import operations as ops
from pymoose.computation import placements as plc
from pymoose.computation import types as ty
from pymoose.computation import values
from pymoose.logger import get_logger

# Classes that the encoder/decoder handle generically: an instance is encoded
# as a dict of its dataclass fields plus a "__type__" tag holding the class
# name, and decoded by looking that name up in TYPE_NAMES.
SUPPORTED_TYPES = [
    # Operations
    ops.AbsOperation,
    ops.AddNOperation,
    ops.AddOperation,
    ops.ArgmaxOperation,
    ops.AtLeast2DOperation,
    ops.BitwiseAndOperation,
    ops.BitwiseOrOperation,
    ops.CastOperation,
    ops.ConcatenateOperation,
    ops.ConstantOperation,
    ops.DecryptOperation,
    ops.DivOperation,
    ops.DotOperation,
    ops.ExpandDimsOperation,
    ops.ExpOperation,
    ops.GreaterOperation,
    ops.IdentityOperation,
    ops.IndexAxisOperation,
    ops.InputOperation,
    ops.InverseOperation,
    ops.LessOperation,
    ops.LoadOperation,
    ops.LogOperation,
    ops.Log2Operation,
    ops.MaximumOperation,
    ops.MeanOperation,
    ops.MulOperation,
    ops.MuxOperation,
    ops.OnesOperation,
    ops.ZerosOperation,
    ops.OutputOperation,
    ops.SigmoidOperation,
    ops.ReluOperation,
    ops.SelectOperation,
    ops.SoftmaxOperation,
    ops.ReshapeOperation,
    ops.SaveOperation,
    ops.ShapeOperation,
    ops.SliceOperation,
    ops.StridedSliceOperation,
    ops.SqueezeOperation,
    ops.SqrtOperation,
    ops.SubOperation,
    ops.SumOperation,
    ops.TransposeOperation,
    # Placements
    plc.HostPlacement,
    plc.ReplicatedPlacement,
    plc.MirroredPlacement,
    # Value types
    ty.AesKeyType,
    ty.AesTensorType,
    ty.BytesType,
    ty.FloatType,
    ty.IntType,
    ty.ShapeType,
    ty.StringType,
    ty.TensorType,
    ty.UnitType,
    ty.UnknownType,
    # Constant values
    values.FloatConstant,
    values.IntConstant,
    values.ShapeConstant,
    values.StringConstant,
    values.TensorConstant,
]
# "__type__" tag (bare class name) -> class. The loop variable is named `cls`
# so it does not visually shadow the `ty` module alias.
TYPE_NAMES = {cls.__name__: cls for cls in SUPPORTED_TYPES}
# Fixed-point dtype names of the form "fixed<int>_<int>"; the decoder passes
# the two captured integers, in order, to `dtypes.fixed`.
FIXED_DTYPE_REGEX = re.compile(r"fixed([0-9]+)_([0-9]+)")


def serialize_computation(computation):
    """Serialize ``computation`` into msgpack bytes.

    Objects that msgpack cannot pack natively are handed to the module's
    ``_encode`` hook, which turns them into ``"__type__"``-tagged dicts.

    Args:
        computation: the computation object to serialize.

    Returns:
        The msgpack-encoded payload as ``bytes``.
    """
    packed = msgpack.packb(computation, default=_encode)
    return packed
def deserialize_computation(bytes_stream):
    """Rebuild a computation from msgpack bytes.

    Every decoded map is passed through the module's ``_decode`` object hook,
    which converts ``"__type__"``-tagged dicts back into objects. Strings are
    decoded as ``str`` (``raw=False``). The result is logged at debug level
    before being returned.

    Args:
        bytes_stream: msgpack-encoded bytes, e.g. from ``serialize_computation``.

    Returns:
        The decoded computation.
    """
    decoded = msgpack.unpackb(bytes_stream, raw=False, object_hook=_decode)
    logger = get_logger()
    logger.debug(decoded)
    return decoded
def _encode(val):
    """Convert an object msgpack cannot pack natively into a tagged dict.

    Used as the ``default`` hook of ``msgpack.packb``. msgpack calls this for
    every object it cannot serialize itself and then packs the returned value,
    calling the hook again for any nested unsupported objects. Each returned
    dict carries a ``"__type__"`` tag that ``_decode`` uses to rebuild the object.

    Args:
        val: the object to encode.

    Returns:
        A dict of msgpack-serializable values tagged with ``"__type__"``.

    Raises:
        NotImplementedError: if ``val`` is of a type this module cannot encode.
    """
    # Builtin / numpy values first. These types are disjoint from the project
    # types checked below, so the order does not change which branch applies.
    if isinstance(val, slice):
        return {
            "__type__": "PySlice",
            "start": val.start,
            "step": val.step,
            "stop": val.stop,
        }
    if isinstance(val, np.ndarray):
        return {
            "__type__": "ndarray",
            "dtype": str(val.dtype),
            # Flattened (row-major); `_decode` restores the layout from "shape".
            "items": val.flatten().tolist(),
            "shape": list(val.shape),
        }
    if isinstance(val, comp_base.Computation):
        return {
            "__type__": "Computation",
            "operations": val.operations,
            "placements": val.placements,
        }
    if isinstance(val, (ops.Operation, ty.ValueType, plc.Placement, values.Value)):
        type_name = type(val).__name__
        if type_name not in TYPE_NAMES:
            # Explicit check instead of `assert` so it is not stripped under
            # `python -O`; a type missing from TYPE_NAMES could not be decoded.
            raise NotImplementedError(f"unsupported type for serialization: {type_name}")
        payload = {field.name: getattr(val, field.name) for field in fields(val)}
        payload["__type__"] = type_name
        return payload
    if isinstance(val, ops.OpSignature):
        return {
            "__type__": "OpSignature",
            "input_types": val.input_types,
            "return_type": val.return_type,
        }
    if isinstance(val, dtypes.DType):
        if FIXED_DTYPE_REGEX.match(val.name):
            # Fixed-point dtypes carry their precisions as separate fields.
            return {
                "__type__": "DType",
                "name": "fixed",
                "integral_precision": val.integral_precision,
                "fractional_precision": val.fractional_precision,
            }
        return {"__type__": "DType", "name": val.name}
    raise NotImplementedError(f"{type(val)}")


def _decode_dtype(obj):
    """Rebuild a ``dtypes.DType`` from its ``{"__type__": "DType", ...}`` dict.

    Raises:
        KeyError: if the dtype name is not recognised.
    """
    name = obj["name"]
    if (
        name == "fixed"
        and "integral_precision" in obj
        and "fractional_precision" in obj
    ):
        # The form `_encode` emits for fixed-point dtypes.
        return dtypes.fixed(
            int(obj["integral_precision"]), int(obj["fractional_precision"])
        )
    fixed_match = FIXED_DTYPE_REGEX.match(name)
    if fixed_match is not None:
        # Also accept names that embed the precisions, e.g. "fixed14_23".
        return dtypes.fixed(int(fixed_match.group(1)), int(fixed_match.group(2)))
    named_dtypes = {
        dtypes.int32.name: dtypes.int32,
        dtypes.int64.name: dtypes.int64,
        dtypes.uint32.name: dtypes.uint32,
        dtypes.uint64.name: dtypes.uint64,
        dtypes.float32.name: dtypes.float32,
        dtypes.float64.name: dtypes.float64,
        dtypes.bool_.name: dtypes.bool_,
    }
    return named_dtypes[name]


def _decode(obj):
    """Turn a ``"__type__"``-tagged dict back into the object it encodes.

    Used as the ``object_hook`` of ``msgpack.unpackb``. It is called on every
    decoded map, and maps without a ``"__type__"`` key are returned unchanged.

    Args:
        obj: a dict produced by msgpack while unpacking.

    Returns:
        The reconstructed object, or ``obj`` itself if it is untagged.

    Raises:
        KeyError: if the tag or a dtype name is unknown.
    """
    if "__type__" not in obj:
        return obj
    type_tag = obj["__type__"]
    if type_tag == "PySlice":
        # Counterpart of the "PySlice" branch in `_encode`.
        return slice(obj["start"], obj["stop"], obj["step"])
    if type_tag == "ndarray":
        return np.array(obj["items"], dtype=obj["dtype"]).reshape(obj["shape"])
    if type_tag == "Computation":
        del obj["__type__"]
        return comp_base.Computation(**obj)
    if type_tag == "DType":
        return _decode_dtype(obj)
    if type_tag == "OpSignature":
        return ops.OpSignature(
            input_types=obj["input_types"],
            return_type=obj["return_type"],
        )
    # Generic case: the tag is a class name from TYPE_NAMES and the remaining
    # keys are that class's constructor arguments.
    cls = TYPE_NAMES[type_tag]
    del obj["__type__"]
    return cls(**obj)