Source code for pymoose.edsl.base

import functools as ft
import inspect
import textwrap
from dataclasses import dataclass
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np

from pymoose.computation import dtypes
from pymoose.computation import types as ty
from pymoose.computation import values

try:  # post python 3.10
    from types import EllipsisType
except ImportError:
    EllipsisType = type(...)

# Stack of active placements. PlacementExpression.__enter__/__exit__ push and
# pop entries; get_current_placement() reads the last (innermost) one.
CURRENT_PLACEMENT: List = []

# Lookup table from numpy dtypes to Moose dtypes. Every supported dtype is
# registered twice: under its numpy scalar type (e.g. ``np.uint32``) and under
# the matching ``np.dtype`` instance.
_NUMPY_DTYPES_MAP = {
    np.uint32: dtypes.uint32,
    np.dtype("uint32"): dtypes.uint32,
    np.uint64: dtypes.uint64,
    np.dtype("uint64"): dtypes.uint64,
    np.int32: dtypes.int32,
    np.dtype("int32"): dtypes.int32,
    np.int64: dtypes.int64,
    np.dtype("int64"): dtypes.int64,
    np.float32: dtypes.float32,
    np.dtype("float32"): dtypes.float32,
    np.float64: dtypes.float64,
    np.dtype("float64"): dtypes.float64,
    np.bool_: dtypes.bool_,
    np.dtype("bool_"): dtypes.bool_,
}

# Module-wide runtime slot: written by set_current_runtime() and read by
# get_current_runtime(). Stays None until a runtime is set.
_CURRENT_RUNTIME = None


def get_current_runtime():
    """Return the runtime held in the module-level runtime slot.

    Returns:
        The value currently bound to ``_CURRENT_RUNTIME``.
    """
    # Reading a module global needs no ``global`` declaration.
    return _CURRENT_RUNTIME
def set_current_runtime(runtime):
    """Store ``runtime`` in the module-level runtime slot.

    Args:
        runtime: The object to bind to ``_CURRENT_RUNTIME``.
    """
    # Rebinding (not mutating) the module global requires the declaration.
    global _CURRENT_RUNTIME
    _CURRENT_RUNTIME = runtime
@dataclass
class PlacementExpression:
    """Named placement that doubles as a context manager.

    Entering the context pushes this placement onto the module-level
    ``CURRENT_PLACEMENT`` stack; leaving it pops the top entry again.
    """

    name: str

    def __enter__(self):
        """Push this placement onto the stack and return it."""
        # The list is mutated in place, so no ``global`` declaration is needed.
        CURRENT_PLACEMENT.append(self)
        # Returning self lets ``with placement as p:`` bind the placement
        # instead of None.
        return self

    def __exit__(self, type, value, traceback):
        """Pop the top of the placement stack; exceptions propagate."""
        CURRENT_PLACEMENT.pop(-1)
@dataclass
class HostPlacementExpression(PlacementExpression):
    """Placement subclass with no extra fields, hashed by its ``name``."""

    def __hash__(self):
        # Explicit hash: @dataclass would otherwise set __hash__ to None
        # alongside its generated __eq__.
        placement_name = self.name
        return hash(placement_name)
@dataclass
class MirroredPlacementExpression(PlacementExpression):
    """Placement subclass carrying a ``players`` list, hashed by ``name``."""

    # Placements that make up this mirrored placement.
    players: List[PlacementExpression]

    def __hash__(self):
        # Explicit hash: @dataclass would otherwise set __hash__ to None
        # alongside its generated __eq__.
        placement_name = self.name
        return hash(placement_name)
@dataclass
class ReplicatedPlacementExpression(PlacementExpression):
    """Placement subclass carrying a ``players`` list, hashed by ``name``."""

    # Placements that make up this replicated placement.
    players: List[PlacementExpression]

    def __hash__(self):
        # Explicit hash: @dataclass would otherwise set __hash__ to None
        # alongside its generated __eq__.
        placement_name = self.name
        return hash(placement_name)
def host_placement(name):
    """Build a :class:`HostPlacementExpression` called ``name``."""
    return HostPlacementExpression(name)
def mirrored_placement(name, players):
    """Build a :class:`MirroredPlacementExpression`.

    Args:
        name: Name of the placement.
        players: Placements stored in the ``players`` field.
    """
    return MirroredPlacementExpression(name, players)
def replicated_placement(name, players):
    """Build a :class:`ReplicatedPlacementExpression`.

    Args:
        name: Name of the placement.
        players: Placements stored in the ``players`` field.
    """
    return ReplicatedPlacementExpression(name, players)
def get_current_placement():
    """Return the innermost placement on the ``CURRENT_PLACEMENT`` stack.

    Raises:
        IndexError: If the stack is empty.
    """
    active_placements = CURRENT_PLACEMENT
    return active_placements[-1]
@dataclass(init=False)
class Argument:
    """A type annotation for the arguments of Moose computations.

    Parameters of Moose computations are annotated with this class to carry
    the extra information needed to identify their graph Values in the
    compiler, and eventually to materialize those values at runtime.

    Args:
        placement: A placement to pin this Argument to. The corresponding
            InputOp for the argument will be pinned to this placement.
        dtype: If the Value is a Tensor, specify its ``dtype``. Using this
            argument is equivalent to specifying
            ``vtype=TensorType(dtype=dtype)``.
        vtype: The Moose Value type for the Argument. This type information
            is used during compilation to check correctness of the
            computation and to lower the computation graph into one that is
            runtime-ready. The type information of arguments is also checked
            at runtime.
    """

    placement: PlacementExpression
    dtype: Optional[dtypes.DType] = None
    vtype: Optional[ty.ValueType] = None

    def __init__(self, placement, dtype=None, vtype=None):
        # Hand-written initializer (the dataclass one is disabled) so that
        # ``vtype`` is derived from ``dtype`` and ``vtype`` by the helper.
        resolved_vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)
        self.placement = placement
        self.dtype = dtype
        self.vtype = resolved_vtype
@dataclass
class Expression:
    """Base node of the eDSL expression graph.

    Besides holding the node data, the class provides indexing sugar
    (``expr[...]``) and arithmetic/comparison operator sugar that build new
    expressions through the module-level eDSL functions.

    Attributes:
        placement: Placement this expression is pinned to.
        inputs: Expressions feeding into this one.
        vtype: Value type of the expression's result, if known.
    """

    placement: PlacementExpression
    inputs: List["Expression"]
    vtype: Optional[ty.ValueType]

    def __hash__(self):
        # Identity hash; defined explicitly because @dataclass would
        # otherwise set __hash__ to None alongside its generated __eq__.
        return id(self)

    # slicing sugar
    def __getitem__(self, slice_spec):
        """Slice a tensor-like or shape expression.

        Args:
            slice_spec: A ``slice``, ``Ellipsis``, or a list/tuple. For
                tensors the list/tuple holds slices and/or ``Ellipsis``; for
                shapes it holds exactly two ints ``(begin, end)``.

        Returns:
            The result of ``strided_slice`` (tensors) or ``sliced`` (shapes).

        Raises:
            ValueError: If the slice specification is not supported.
            IndexError: If this expression's vtype cannot be sliced.
        """
        # TODO explicitly construe placement from
        # global placement context and/or self.placement?
        # NOTE: earlier asserts on the vtype and on each tuple entry made the
        # descriptive ValueError/IndexError branches below unreachable unless
        # running under ``python -O``; they are now raised unconditionally.
        assert isinstance(slice_spec, (slice, EllipsisType, list, tuple))
        if isinstance(self.vtype, (ty.TensorType, ty.AesTensorType)):
            # turn single entry to a list of entries
            if isinstance(slice_spec, (slice, EllipsisType)):
                slice_spec = (slice_spec,)

            slice_rewrite = []
            for cur_slice in slice_spec:
                if isinstance(cur_slice, EllipsisType):
                    # Each Ellipsis entry is rewritten to one full slice.
                    slice_rewrite.append(slice(None, None, None))
                elif isinstance(cur_slice, slice):
                    slice_rewrite.append(cur_slice)
                else:
                    raise ValueError(
                        "Indexing with other types different than Ellipsis and slice "
                        "is not yet supported."
                    )
            return strided_slice(self, slices=slice_rewrite)
        elif isinstance(self.vtype, ty.ShapeType):
            if isinstance(slice_spec, (tuple, list)):
                # Exactly (begin, end) is required; shorter sequences used to
                # fail with an opaque unpacking error.
                if len(slice_spec) != 2:
                    raise ValueError(
                        "Indexing ShapeType requires a simple slice, including only "
                        "`start` & `stop` slice values."
                    )
                begin, end = slice_spec
                assert isinstance(begin, int) and isinstance(end, int)
            elif isinstance(slice_spec, slice):
                if slice_spec.step is not None:
                    raise ValueError(
                        "Indexing ShapeType requires a simple slice, including only "
                        "`start` & `stop` slice values."
                    )
                begin, end = slice_spec.start, slice_spec.stop
            else:
                # Ellipsis: previously fell through with ``begin``/``end``
                # unbound and crashed with UnboundLocalError.
                raise ValueError(
                    "Indexing ShapeType requires a simple slice, including only "
                    "`start` & `stop` slice values."
                )
            return sliced(self, begin, end)
        else:
            raise IndexError(f"Expression of vtype {self.vtype} is not slice-able.")

    # arithmetic sugar
    def __neg__(self):
        """Negate by multiplying with a constant ``-1`` of the same vtype."""
        _check_arithmetickable(self, "negate")
        if isinstance(self.vtype, ty.TensorType):
            if not self.vtype.dtype.is_signed:
                raise TypeError(
                    f"Cannot negate Tensor of unsigned DType {self.vtype.dtype}."
                )
        negative_one = constant(-1, vtype=self.vtype)
        return self.__rmul__(negative_one)

    def __abs__(self):
        """Absolute value; unsigned tensors are returned unchanged."""
        _check_arithmetickable(self, "abs")
        if isinstance(self.vtype, ty.TensorType):
            if not self.vtype.dtype.is_signed:
                return self
        # NOTE(review): this relies on a module-level eDSL ``abs`` function
        # shadowing the builtin; with the builtin this call would recurse
        # back into __abs__ -- confirm such a function exists in this module.
        return abs(self)

    def __add__(self, other):
        return _binary_dunder_method(self, other, add, "add")

    def __radd__(self, other):
        return _binary_dunder_method(other, self, add, "add")

    def __iadd__(self, other):
        return _binary_dunder_method(self, other, add, "add")

    def __sub__(self, other):
        return _binary_dunder_method(self, other, sub, "subtract")

    def __rsub__(self, other):
        return _binary_dunder_method(other, self, sub, "subtract")

    def __isub__(self, other):
        return _binary_dunder_method(self, other, sub, "subtract")

    def __mul__(self, other):
        return _binary_dunder_method(self, other, mul, "multiply")

    def __rmul__(self, other):
        return _binary_dunder_method(other, self, mul, "multiply")

    def __imul__(self, other):
        return _binary_dunder_method(self, other, mul, "multiply")

    def __truediv__(self, other):
        return _binary_dunder_method(self, other, div, "divide")

    def __rtruediv__(self, other):
        return _binary_dunder_method(other, self, div, "divide")

    def __itruediv__(self, other):
        return _binary_dunder_method(self, other, div, "divide")

    def __matmul__(self, other):
        return _binary_dunder_method(self, other, dot, "dot-product")

    def __rmatmul__(self, other):
        return _binary_dunder_method(other, self, dot, "dot-product")

    def __imatmul__(self, other):
        return _binary_dunder_method(self, other, dot, "dot-product")

    def __gt__(self, other):
        return _binary_dunder_method(self, other, greater, "greater-than")

    def __lt__(self, other):
        return _binary_dunder_method(self, other, less, "less-than")
def _binary_dunder_method(x, y, fn, fn_desc):
    """Validate both operands, then apply the binary eDSL function ``fn``.

    Args:
        x: Left operand.
        y: Right operand.
        fn: Callable taking ``(x, y)`` and returning the result.
        fn_desc: Operation name used in the error message.

    Returns:
        Whatever ``fn(x, y)`` returns.
    """
    for operand in (x, y):
        _check_arithmetickable(operand, fn_desc)
    return fn(x, y)


def _check_arithmetickable(expr, fn_name):
    """Raise TypeError unless ``expr.vtype`` is a tensor, float or int type.

    Args:
        expr: Object whose ``vtype`` attribute is inspected.
        fn_name: Operation name used in the error message.
    """
    if isinstance(expr.vtype, (ty.TensorType, ty.FloatType, ty.IntType)):
        return
    raise TypeError(f"Value of vtype {expr.vtype} is not {fn_name}-able.")
@dataclass
class AddNExpression(Expression):
    """Expression node built by :func:`add_n`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class IdentityExpression(Expression):
    """Expression node built by :func:`identity`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ArgumentExpression(Expression):
    """Expression subclass that additionally stores an ``arg_name``."""

    # Name of the argument this expression stands for.
    arg_name: str

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ConcatenateExpression(Expression):
    """Expression node built by :func:`concatenate`; stores the ``axis``."""

    # Dimension to concatenate along.
    axis: Optional[int]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class MaximumExpression(Expression):
    """Expression node built by :func:`maximum`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class DecryptExpression(Expression):
    """Expression node built by :func:`decrypt`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ConstantExpression(Expression):
    """Expression node built by :func:`constant`; stores the ``value``."""

    # The embedded constant.
    value: Union[int, float]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class BinaryAndExpression(Expression):
    """Expression subclass that additionally stores an ``op_name``."""

    # Name of the operation.
    op_name: str

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class BinaryOpExpression(Expression):
    """Expression node for two-operand ops, tagged with an ``op_name``.

    Built by the binary eDSL functions such as :func:`add`, :func:`mul`,
    :func:`less` or :func:`logical_or`.
    """

    # Name of the binary operation, e.g. "add" or "less".
    op_name: str

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ExpandDimsExpression(Expression):
    """Expression node built by :func:`expand_dims`; stores the ``axis``."""

    # Indices at which new dimensions are inserted.
    axis: Tuple[int]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SqueezeExpression(Expression):
    """Expression node built by :func:`squeeze`; stores the ``axis``."""

    # Dimension(s) to squeeze, or None.
    axis: Optional[Union[int, Tuple[int]]]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class OnesExpression(Expression):
    """Expression node built by :func:`ones`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ZerosExpression(Expression):
    """Expression node built by :func:`zeros`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SquareExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SumExpression(Expression):
    """Expression node built by :func:`sum`; stores the ``axis``."""

    # Dimension(s) to reduce over, or None.
    axis: Optional[Union[int, Tuple[int]]]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class MeanExpression(Expression):
    """Expression node built by :func:`mean`; stores the ``axis``."""

    # Dimension(s) to reduce over, or None.
    axis: Optional[Union[int, Tuple[int]]]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ExpExpression(Expression):
    """Expression node built by :func:`exp`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SigmoidExpression(Expression):
    """Expression node built by :func:`sigmoid`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SoftmaxExpression(Expression):
    """Expression node built by :func:`softmax`.

    Stores the ``axis`` and ``upmost_index`` passed to the builder.
    """

    # Dimension along which the softmax is taken.
    axis: Optional[Union[int, Tuple[int]]]
    # Max index used along ``axis``.
    upmost_index: int

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ReluExpression(Expression):
    """Expression node built by :func:`relu`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ArgmaxExpression(Expression):
    """Expression node built by :func:`argmax`.

    Stores the ``axis`` and ``upmost_index`` passed to the builder.
    """

    # Dimension along which the argmax is taken.
    axis: Optional[Union[int, Tuple[int]]]
    # Max index considered along ``axis``.
    upmost_index: int

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class LogExpression(Expression):
    """Expression node built by :func:`log`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class Log2Expression(Expression):
    """Expression node built by :func:`log2`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SqrtExpression(Expression):
    """Expression node built by :func:`sqrt`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class TransposeExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ReshapeExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class AtLeast2DExpression(Expression):
    """Expression subclass that stores a ``to_column_vector`` flag."""

    # Boolean option carried by the node.
    to_column_vector: bool

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class LoadExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class InverseExpression(Expression):
    """Expression node built by :func:`inverse`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class AbsExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class CastExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SaveExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class ShapeExpression(Expression):
    """Expression node built by :func:`shape`; adds no fields."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class IndexAxisExpression(Expression):
    """Expression node built by :func:`index_axis`.

    Stores the ``axis`` and ``index`` passed to the builder.
    """

    # Dimension being indexed.
    axis: int
    # Position taken along ``axis``.
    index: int

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SelectExpression(Expression):
    """Expression subclass that additionally stores an ``axis``."""

    # Integer axis carried by the node.
    axis: int

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class SliceExpression(Expression):
    """Expression subclass that stores ``begin`` and ``end`` bounds."""

    # Start of the slice.
    begin: int
    # End of the slice.
    end: int

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class StridedSliceExpression(Expression):
    """Expression subclass that stores a collection of ``slices``."""

    # Slice objects carried by the node, or None.
    slices: Optional[Tuple[slice]]

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class LessExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class GreaterExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class BitwiseAndExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class BitwiseOrExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class MuxExpression(Expression):
    """Expression subclass with no extra fields, hashed by identity."""

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
@dataclass
class OutputExpression(Expression):
    """Expression subclass that additionally stores a ``tag``."""

    # String tag carried by the node.
    tag: str

    def __hash__(self):
        # Identity hash; an explicit __hash__ keeps the dataclass hashable.
        return id(self)
def _docinject_placement_arg(f): """Appends a ``placement`` entry to the Args section of a docstring. Docstring must be Google style. ``placement`` should be the last keyword argument, otherwise you should document it manually. Args: f: Callable whose docstring you want to amend. Returns: The function f with a modified docstring """ docstring = f.__doc__ lines = docstring.split("\n") in_args = False for i, line in enumerate(lines): if i == len(lines) - 1: break nextline = lines[i + 1] dedented_line = textwrap.dedent(line) line_indentation = len(line) - len(dedented_line) if dedented_line.startswith("Args:"): in_args = True arg_indentation = line_indentation + 4 continue if in_args and nextline == "": placement_arg_doc = textwrap.indent( "placement: An optional :class:`~pymoose.computation.placements.Placement` to pin this operation to.", " " * arg_indentation, ) lines.insert(i + 1, placement_arg_doc) in_args = False new_docstring = "\n".join(lines) f.__doc__ = new_docstring return f
@_docinject_placement_arg
def add_n(arrays, placement=None):
    """Elementwise addition of a collection of tensors.

    Args:
        arrays: Non-empty tuple or list of tensors w/ identical shape and
            dtype.

    Returns:
        A tensor containing the elementwise sum of all input tensors.

    Raises:
        ValueError: If ``arrays`` is not a non-empty tuple/list, or if its
            elements are not tensors of one common vtype and dtype.
    """
    placement = _materialize_placement_arg(placement)
    if not isinstance(arrays, (tuple, list)):
        raise ValueError(
            "Inputs to `add_n` must be array-like, found argument "
            f"of type {type(arrays)}."
        )
    if not arrays:
        # Previously surfaced as a bare IndexError from ``arrays[0]``.
        raise ValueError("Inputs to `add_n` must contain at least one tensor.")

    input_vtype = arrays[0].vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(f"Inputs must have vtype TensorType, found {input_vtype}.")
    expected_dtype = input_vtype.dtype

    for array in arrays:
        if not isinstance(array.vtype, ty.TensorType):
            raise ValueError(
                f"Inputs must have vtype TensorType, found {array.vtype}."
            )
        # The dtype lives on the vtype; expressions have no ``dtype`` field.
        if array.vtype.dtype != expected_dtype:
            raise ValueError(
                "Values passed to add_n must be same dtype: found "
                f"{array.vtype.dtype} and {expected_dtype} in value of `arrays` "
                "argument."
            )
        if array.vtype != input_vtype:
            raise ValueError(
                "Values passed to add_n must share one vtype: found "
                f"{array.vtype} and {input_vtype} in value of `arrays` argument."
            )

    return AddNExpression(placement=placement, inputs=arrays, vtype=input_vtype)
@_docinject_placement_arg
def identity(x, placement=None):
    """The identity operation, ``f(x) = x``.

    The value of ``x`` is guaranteed to stay the same, but note that the
    result may live on a different placement than ``x``.

    Args:
        x: Input value to be returned.

    Returns:
        The unchanged value ``x``.
    """
    target = _materialize_placement_arg(placement)
    return IdentityExpression(placement=target, inputs=[x], vtype=x.vtype)
@_docinject_placement_arg
def concatenate(arrays, axis=0, placement=None):
    """Concatenation of a collection of tensors along a given axis/dimension.

    Note the tensors must have similar dtype and shape along all dimensions
    but the given one. The dimension must already exist in the input tensors;
    use :func:`expand_dims` on inputs to create a new dimension for
    concatenating along.

    Args:
        arrays: Non-empty tuple or list of tensors to be concatenated.
        axis: Optional integer representing the dimension to concatenate
            along. Dimension must already exist for all input tensors.

    Returns:
        A tensor expression holding the concatenation of all input tensors.

    Raises:
        ValueError: If ``arrays`` is not a non-empty tuple/list, or if its
            elements are not tensors of one common vtype and dtype.
    """
    placement = _materialize_placement_arg(placement)
    if not isinstance(arrays, (tuple, list)):
        raise ValueError(
            "Inputs to `concatenate` must be array-like, found argument "
            f"of type {type(arrays)}."
        )
    if not arrays:
        # Previously surfaced as a bare IndexError from ``arrays[0]``.
        raise ValueError("Inputs to `concatenate` must contain at least one tensor.")

    input_vtype = arrays[0].vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(f"Inputs must have vtype TensorType, found {input_vtype}.")
    expected_dtype = input_vtype.dtype

    for array in arrays:
        if not isinstance(array.vtype, ty.TensorType):
            raise ValueError(
                f"Inputs must have vtype TensorType, found {array.vtype}."
            )
        # The dtype lives on the vtype; expressions have no ``dtype`` field.
        if array.vtype.dtype != expected_dtype:
            raise ValueError(
                "Values passed to concatenate must be same dtype: found "
                f"{array.vtype.dtype} and {expected_dtype} in value of `arrays` "
                "argument."
            )
        if array.vtype != input_vtype:
            raise ValueError(
                "Values passed to concatenate must share one vtype: found "
                f"{array.vtype} and {input_vtype} in value of `arrays` argument."
            )

    return ConcatenateExpression(
        placement=placement, inputs=arrays, axis=axis, vtype=input_vtype
    )
@_docinject_placement_arg
def maximum(arrays, placement=None):
    """Elementwise maximum of a collection of input tensors.

    Args:
        arrays: Non-empty tuple or list of tensors w/ identical shape and
            dtype.

    Returns:
        A tensor containing the elementwise max of all input tensors.

    Raises:
        ValueError: If ``arrays`` is not a non-empty tuple/list, or if its
            elements are not tensors of one common vtype and dtype.
    """
    placement = _materialize_placement_arg(placement)
    if not isinstance(arrays, (tuple, list)):
        # The message used to name `concatenate` (copy-paste slip).
        raise ValueError(
            "Inputs to `maximum` must be array-like, found argument "
            f"of type {type(arrays)}."
        )
    if not arrays:
        # Previously surfaced as a bare IndexError from ``arrays[0]``.
        raise ValueError("Inputs to `maximum` must contain at least one tensor.")

    input_vtype = arrays[0].vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(f"Inputs must have vtype TensorType, found {input_vtype}.")
    expected_dtype = input_vtype.dtype

    for array in arrays:
        if not isinstance(array.vtype, ty.TensorType):
            raise ValueError(
                f"Inputs must have vtype TensorType, found {array.vtype}."
            )
        # The dtype lives on the vtype; expressions have no ``dtype`` field.
        if array.vtype.dtype != expected_dtype:
            raise ValueError(
                "Values passed to maximum must be same dtype: found "
                f"{array.vtype.dtype} and {expected_dtype} in value of `arrays` "
                "argument."
            )
        if array.vtype != input_vtype:
            raise ValueError(
                "Values passed to maximum must share one vtype: found "
                f"{array.vtype} and {input_vtype} in value of `arrays` argument."
            )

    return MaximumExpression(placement=placement, inputs=arrays, vtype=input_vtype)
def decrypt(key, ciphertext, placement=None):
    """Decrypt an AES-encrypted tensor with the given key.

    Args:
        key: Expression of vtype ``AesKeyType``.
        ciphertext: Expression of vtype ``AesTensorType``.
        placement: An optional placement to pin this operation to.

    Returns:
        A :class:`DecryptExpression` whose vtype is a ``TensorType`` with the
        same dtype as ``ciphertext``.

    Raises:
        ValueError: If ``key`` or ``ciphertext`` has the wrong vtype.
    """
    placement = _materialize_placement_arg(placement)

    # key expr typecheck
    if not isinstance(key.vtype, ty.AesKeyType):
        # Needs the f-prefix, otherwise ``{key.vtype}`` is printed literally.
        raise ValueError(
            f"Parameter `key` expected to be of type AesKeyType, found {key.vtype}."
        )

    # ciphertext expr typecheck
    if not isinstance(ciphertext.vtype, ty.AesTensorType):
        raise ValueError(
            "Parameter `ciphertext` expected to be of type AesTensorType, "
            f"found {ciphertext.vtype}."
        )

    # decrypt converts AesTensorType(fixed(i, f)) -> TensorType(fixed(i, f))
    output_dtype = ciphertext.vtype.dtype
    output_type = ty.TensorType(output_dtype)
    return DecryptExpression(
        placement=placement,
        inputs=[key, ciphertext],
        vtype=output_type,
    )
@_docinject_placement_arg
def constant(value, dtype=None, vtype=None, placement=None):
    """Embed a constant of a particular type.

    Args:
        value: The Python value to embed as constant. Supported types include
            the Python native variants of those listed in
            ``pymoose.computation.types``. Tensors are embeddable as numpy
            ndarrays.
        dtype: If given, coerces ``value`` into a Moose tensor with this dtype.
            Otherwise, if ``value`` is an ndarray, the Moose dtype is inferred
            from the ndarray's numpy dtype.
        vtype: If given, coerces ``value`` into a Moose value with this vtype,
            i.e. ValueType. Otherwise, infers the vtype from ``value``'s Python
            type.

    Returns:
        A Moose Value of type ``vtype`` embedded into the computation graph.

    Raises:
        NotImplementedError: If ``value`` is an ndarray of unsupported dtype.
        TypeError: If an implicit cast is needed but the target dtype is not
            a ``DType``.
        ValueError: If ``value`` is a ``str`` but ``vtype`` is not a string
            type.
    """
    placement = _materialize_placement_arg(placement)
    vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)
    if isinstance(value, np.ndarray):
        moose_dtype = _NUMPY_DTYPES_MAP.get(value.dtype.type, None)
        if moose_dtype is None:
            raise NotImplementedError(
                f"Tensors of dtype `{value.dtype}` not supported as graph constants."
            )
        if vtype is not None and moose_dtype != vtype.dtype:
            # Requested dtype differs from the array's own: embed the array
            # as-is, then cast it to the requested dtype.
            dtype = dtype or vtype.dtype
            if not isinstance(dtype, dtypes.DType):
                raise TypeError(
                    "`dtype` argument to `constant` must be of type DType, "
                    f"found {type(dtype)}."
                )
            implicit_const = constant(value, dtype=moose_dtype, placement=placement)
            return cast(implicit_const, dtype, placement)
        elif vtype is None:
            vtype = ty.TensorType(moose_dtype)
        value = values.TensorConstant(value=value)
    elif isinstance(value, float):
        if isinstance(vtype, ty.TensorType) and vtype.dtype.is_fixedpoint:
            # want to use implicit casting, so simply wrap as ndarray and recurse.
            # The placement is forwarded so an explicit ``placement`` argument
            # is not replaced by the ambient one in the recursive call.
            return constant(np.array(value), vtype=vtype, placement=placement)
        value, vtype = _interpret_numeric_value(value, vtype, ty.FloatType())
    elif isinstance(value, int):
        if isinstance(vtype, ty.TensorType) and vtype.dtype.is_fixedpoint:
            # want to use implicit casting, so simply wrap as ndarray and recurse
            # (placement forwarded for the same reason as above).
            return constant(np.array(value), vtype=vtype, placement=placement)
        value, vtype = _interpret_numeric_value(value, vtype, ty.IntType())
    elif isinstance(value, str):
        vtype = vtype or ty.StringType()
        if not isinstance(vtype, ty.StringType):
            raise ValueError(
                "Constant value of type `str` does not match "
                f"user-supplied vtype argument `{vtype}`."
            )
        value = values.StringConstant(value=value)

    # Any other kind of ``value`` is embedded unchanged.
    return ConstantExpression(placement=placement, inputs=[], value=value, vtype=vtype)
@_docinject_placement_arg
def add(lhs, rhs, placement=None):
    """Add two values, i.e. compute ``lhs + rhs``.

    Tensor inputs are added elementwise.

    Args:
        lhs: First addend.
        rhs: Second addend.

    Returns:
        Sum of the two inputs.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "add")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="add"
    )
@_docinject_placement_arg
def sub(lhs, rhs, placement=None):
    """Subtract two values, i.e. compute ``lhs - rhs``.

    Tensor inputs are subtracted elementwise.

    Args:
        lhs: Minuend.
        rhs: Subtrahend.

    Returns:
        Difference of the two inputs.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "sub")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="sub"
    )
@_docinject_placement_arg
def mul(lhs, rhs, placement=None):
    """Multiply two values, i.e. compute ``lhs * rhs``.

    Tensor inputs are multiplied elementwise.

    Args:
        lhs: First factor.
        rhs: Second factor.

    Returns:
        Product of the two inputs.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "mul")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="mul"
    )
@_docinject_placement_arg
def dot(lhs, rhs, placement=None):
    """Dot product of two tensors.

    Contracts the tensors along the second and first dimensions of the
    respective inputs. For 1 and 2 dimensional inputs, this is equivalent to
    ``np.dot(lhs, rhs)``.

    Args:
        lhs: Left-hand tensor factor.
        rhs: Right-hand tensor factor.

    Returns:
        Dot product of the two input tensors.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "dot")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="dot"
    )
@_docinject_placement_arg
def div(lhs, rhs, placement=None):
    """Divide two values, i.e. compute ``lhs / rhs``.

    Tensor inputs are divided elementwise.

    Args:
        lhs: Dividend.
        rhs: Divisor.

    Returns:
        Quotient of the two inputs.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "div")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="div"
    )
@_docinject_placement_arg
def less(lhs, rhs, placement=None):
    """Evaluate the boolean less-than operation, i.e. ``lhs < rhs``.

    Tensor inputs are compared elementwise.

    Args:
        lhs: Left-hand side of comparison.
        rhs: Right-hand side of comparison.

    Returns:
        The comparison of the two inputs, typed as a boolean tensor.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    # The result vtype is always a bool tensor, regardless of input vtypes.
    bool_tensor = ty.TensorType(dtype=dtypes.bool_)
    return BinaryOpExpression(
        vtype=bool_tensor, inputs=[lhs, rhs], placement=placement, op_name="less"
    )
@_docinject_placement_arg
def greater(lhs, rhs, placement=None):
    """Evaluate the boolean greater-than operation, i.e. ``lhs > rhs``.

    Tensor inputs are compared elementwise.

    Args:
        lhs: Left-hand side of comparison.
        rhs: Right-hand side of comparison.

    Returns:
        The comparison of the two inputs, typed as a boolean tensor.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    # The result vtype is always a bool tensor, regardless of input vtypes.
    bool_tensor = ty.TensorType(dtype=dtypes.bool_)
    return BinaryOpExpression(
        vtype=bool_tensor, inputs=[lhs, rhs], placement=placement, op_name="greater"
    )
@_docinject_placement_arg
def logical_and(lhs, rhs, placement=None):
    """Evaluate the boolean AND operation, i.e. ``lhs & rhs``.

    Tensor inputs are combined elementwise.

    Args:
        lhs: Left-hand side of operation.
        rhs: Right-hand side of operation.

    Returns:
        The logical intersection of the two inputs when treated as booleans.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "and")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="and"
    )
@_docinject_placement_arg
def logical_or(lhs, rhs, placement=None):
    """Evaluate the boolean OR operation, i.e. ``lhs | rhs``.

    Tensor inputs are combined elementwise.

    Args:
        lhs: Left-hand side of operation.
        rhs: Right-hand side of operation.

    Returns:
        The logical union of the two inputs when treated as booleans.
    """
    assert isinstance(lhs, Expression)
    assert isinstance(rhs, Expression)
    placement = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "or")
    operands = [lhs, rhs]
    return BinaryOpExpression(
        vtype=result_vtype, inputs=operands, placement=placement, op_name="or"
    )
def inverse(x, placement=None):
    """Invert a floating-point matrix.

    Args:
        x: A 2-dimensional float tensor to invert.
        placement: An optional placement to pin this operation to.

    Returns:
        The matrix inverse of ``x``.

    Raises:
        ValueError: If ``x`` is not a tensor of dtype ``float32`` or
            ``float64``.
    """
    assert isinstance(x, Expression)
    placement = _materialize_placement_arg(placement)
    input_vtype = x.vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(
            "`inverse` operation only supports arguments of type TensorType."
        )
    is_float = input_vtype.dtype in [dtypes.float32, dtypes.float64]
    if not is_float:
        raise ValueError(
            "`inverse` operation only supports arguments of dtype `float32` or "
            "`float64`."
        )
    return InverseExpression(placement=placement, inputs=[x], vtype=input_vtype)
@_docinject_placement_arg
def expand_dims(x, axis, placement=None):
    """Expand the rank of ``x`` with new singleton dimensions.

    Args:
        x: The tensor to expand.
        axis: Index of the new dimension. If a tuple/list, should contain the
            indices where each new dimension will be expanded.

    Returns:
        The expanded tensor, complete with new singleton dimensions.

    Raises:
        ValueError: If ``axis`` is neither an int nor a list/tuple of ints.
    """
    assert isinstance(x, Expression)
    if isinstance(axis, (tuple, list)):
        for ax in axis:
            if not isinstance(ax, int):
                raise ValueError(
                    "`axis` argument must be int or list/tuple of ints, found "
                    f"{type(ax)}"
                )
    elif isinstance(axis, int):
        # Normalize a single index to a one-element list.
        axis = [axis]
    else:
        # Anything else used to be passed through unvalidated.
        raise ValueError(
            "`axis` argument must be int or list/tuple of ints, found "
            f"{type(axis)}"
        )
    placement = _materialize_placement_arg(placement)
    return ExpandDimsExpression(
        placement=placement, inputs=[x], axis=axis, vtype=x.vtype
    )
@_docinject_placement_arg
def squeeze(x, axis=None, placement=None):
    """Reduce out any singleton dimensions of ``x``.

    Args:
        x: The tensor from which to drop singleton dimensions.
        axis: Optional index into the shape of ``x`` denoting which singleton
            dimension to drop. If None, drops all singleton dimensions in
            ``x``.

    Returns:
        The squeezed tensor with fewer singleton dimensions.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return SqueezeExpression(vtype=x.vtype, axis=axis, inputs=[x], placement=target)
@_docinject_placement_arg
def ones(shape, dtype, placement=None):
    """Embed a ones array into the Moose computation graph.

    Equivalent to `pm.constant(np.ones(shape, dtype))` for a given ``dtype``
    and ``shape``.

    Args:
        shape: Shape of the ones array, given either as an expression or as
            a Python list/tuple.
        dtype: Dtype of the ones array.

    Returns:
        A tensor of all 1s with given shape and dtype.
    """
    # Lists/tuples must pass this check too: the old Expression-only assert
    # made the list/tuple branch below unreachable.
    assert isinstance(shape, (Expression, list, tuple))
    placement = _materialize_placement_arg(placement)
    if isinstance(shape, (list, tuple)):
        # TODO (Yann) Currently we only have the ability to declare HostShape
        # as constant. We should add the ability to declare RepShape as constant.
        if isinstance(placement, ReplicatedPlacementExpression):
            # Pin the shape constant to the first player of the placement.
            shape_placement = placement.players[0]
        else:
            shape_placement = placement
        shape = constant(
            values.ShapeConstant(value=shape),
            vtype=ty.ShapeType(),
            placement=shape_placement,
        )
    vtype = ty.TensorType(dtype)
    return OnesExpression(placement=placement, inputs=[shape], vtype=vtype)
@_docinject_placement_arg
def zeros(shape, dtype, placement=None):
    """Embed a zeros array into the Moose computation graph.

    Equivalent to `pm.constant(np.zeros(shape, dtype))` for a given ``dtype``
    and ``shape``.

    Args:
        shape: Shape of the zeros array, given either as an expression or as
            a Python list/tuple.
        dtype: Dtype of the zeros array.

    Returns:
        A tensor of all 0s with given shape and dtype.
    """
    # Lists/tuples must pass this check too: the old Expression-only assert
    # made the list/tuple branch below unreachable.
    assert isinstance(shape, (Expression, list, tuple))
    placement = _materialize_placement_arg(placement)
    if isinstance(shape, (list, tuple)):
        # TODO (Yann) Currently we only have the ability to declare HostShape
        # as constant. We should add the ability to declare RepShape as constant.
        if isinstance(placement, ReplicatedPlacementExpression):
            # Pin the shape constant to the first player of the placement.
            shape_placement = placement.players[0]
        else:
            shape_placement = placement
        shape = constant(
            values.ShapeConstant(value=shape),
            vtype=ty.ShapeType(),
            placement=shape_placement,
        )
    vtype = ty.TensorType(dtype)
    return ZerosExpression(placement=placement, inputs=[shape], vtype=vtype)
@_docinject_placement_arg
def square(x, placement=None):
    """Square an input value.

    Implemented as ``mul(x, x)``, so a tensor input is squared elementwise.

    Args:
        x: A value to square.

    Returns:
        The squared input.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return mul(x, x, placement=target)
@_docinject_placement_arg
def sum(x, axis=None, placement=None):
    """Sum-reduce an input tensor.

    Adds up the tensor elements along a particular axis, or over the entire
    tensor if ``axis=None``.

    Args:
        x: A tensor.
        axis: An optional dimension along which to sum-reduce the tensor. If
            None, sum-reduces the entire tensor and outputs a scalar tensor.

    Returns:
        The summed input, with one or all dimensions reduced out.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return SumExpression(vtype=x.vtype, axis=axis, inputs=[x], placement=target)
@_docinject_placement_arg
def mean(x, axis=None, placement=None):
    """Mean-reduce an input tensor.

    Args:
        x: A tensor.
        axis: An optional dimension along which to mean-reduce the tensor. If
            None, mean-reduces the entire tensor and outputs a scalar tensor.

    Returns:
        The averaged input, with one or all dimensions reduced out.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return MeanExpression(vtype=x.vtype, axis=axis, inputs=[x], placement=target)
@_docinject_placement_arg
def exp(x, placement=None):
    """Elementwise exponential function: :math:`e^x`.

    If the input is a tensor, the operation is performed elementwise.

    Args:
        x: A value.

    Returns:
        The exponentiated input.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return ExpExpression(vtype=x.vtype, inputs=[x], placement=target)
@_docinject_placement_arg
def sqrt(x, placement=None):
    r"""Compute the square-root of a value :math:`\sqrt{x}`.

    If the input is a tensor, the operation is performed elementwise.

    Args:
        x: A value.

    Returns:
        The square-root of the input.
    """
    # The docstring is a raw string: ``\s`` is an invalid escape sequence in
    # a regular string literal and triggers a warning at compile time.
    assert isinstance(x, Expression)
    placement = _materialize_placement_arg(placement)
    return SqrtExpression(placement=placement, inputs=[x], vtype=x.vtype)
@_docinject_placement_arg
def sigmoid(x, placement=None):
    r"""Apply the sigmoid function, :math:`\frac{1}{1 + e^{-x}}`.

    If the input is a tensor, the operation is performed elementwise.

    Args:
        x: A value.

    Returns:
        The result of applying the sigmoid function to ``x``.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return SigmoidExpression(vtype=x.vtype, inputs=[x], placement=target)
@_docinject_placement_arg
def relu(x, placement=None):
    """Apply the rectified-linear unit (ReLU) function, ``f(x) = max(0, x)``.

    If the input is a tensor, the operation is performed elementwise.

    Args:
        x: A value.

    Returns:
        A value corresponding to ``ReLU(x)``.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return ReluExpression(vtype=x.vtype, inputs=[x], placement=target)
@_docinject_placement_arg
def softmax(x, axis, upmost_index, placement=None):
    r"""Softmax function.

    .. math::
        \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Args:
        x: A tensor.
        axis: The dimension along which the softmax divisor's sum-reduce
            should be computed.
        upmost_index: The max index that should be used for computing the
            softmax divisor's sum-reduce. Generally, this should be the size
            of the ``axis`` dimension of ``x``.

    Returns:
        Result of applying the softmax function on ``x`` along the given
        ``axis``.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return SoftmaxExpression(
        vtype=x.vtype,
        upmost_index=upmost_index,
        axis=axis,
        inputs=[x],
        placement=target,
    )
@_docinject_placement_arg
def argmax(x, axis, upmost_index, placement=None):
    """Compute the index of the maximal element of a tensor along a given dimension.

    Args:
        x: A tensor.
        axis: The dimension along which to compute the argmax.
        upmost_index: The max index that should be considered for computing
            the argmax. Generally, this should be the size of the ``axis``
            dimension of ``x``.

    Returns:
        A dimension-reduced tensor representing the argmax of ``x`` along ``axis``.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    # ``axis`` and ``upmost_index`` are forwarded unchanged as attributes of
    # the expression; the expression is tagged with the value type of ``x``.
    op_attributes = {"axis": axis, "upmost_index": upmost_index}
    return ArgmaxExpression(
        placement=resolved_placement,
        inputs=[x],
        vtype=x.vtype,
        **op_attributes,
    )
@_docinject_placement_arg
def log(x, placement=None):
    """Compute the elementwise natural logarithm of a tensor.

    Args:
        x: A tensor.

    Returns:
        The tensor representing ``log(x)``.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    # Output keeps the value type of the operand.
    return LogExpression(placement=resolved_placement, inputs=[x], vtype=x.vtype)
@_docinject_placement_arg
def log2(x, placement=None):
    """Compute the elementwise base-2 logarithm of a tensor.

    Args:
        x: A tensor.

    Returns:
        The tensor representing ``log_2(x)``.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    # Output keeps the value type of the operand.
    return Log2Expression(
        placement=resolved_placement,
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def shape(x, placement=None):
    """Compute the shape of an input tensor.

    Args:
        x: A tensor.

    Returns:
        A Moose value of type ``ShapeType`` corresponding to the shape of
        the input tensor.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    # Unlike most ops here, the result type is fixed to ``ShapeType`` rather
    # than inherited from the operand.
    result_vtype = ty.ShapeType()
    return ShapeExpression(
        placement=resolved_placement,
        inputs=[x],
        vtype=result_vtype,
    )
@_docinject_placement_arg
def index_axis(x, axis, index, placement=None):
    """Index a tensor along a given dimension.

    Both ``axis`` and ``index`` must be non-negative ints; a ``ValueError``
    is raised otherwise.

    Args:
        x: A tensor.
        axis: The dimension along which to index.
        index: The position to select along ``axis``.

    Returns:
        A slice of ``x`` corresponding to the sub-tensor at ``index`` along ``axis``.
    """
    assert isinstance(x, Expression)
    # Validate ``axis`` against itself. The sign check previously looked at
    # ``index``, which let negative axes through and misreported negative
    # indices as an ``axis`` problem.
    if not isinstance(axis, int) or axis < 0:
        raise ValueError(
            "`axis` argument must be int greater or equal to 0, found "
            f"{axis} of type {type(axis)}"
        )
    if not isinstance(index, int) or index < 0:
        raise ValueError(
            "`index` argument must be int greater or equal to 0, found "
            f"{index} of type {type(index)}"
        )
    placement = _materialize_placement_arg(placement)
    return IndexAxisExpression(
        placement=placement, inputs=[x], axis=axis, index=index, vtype=x.vtype
    )
@_docinject_placement_arg
def select(x, axis, index, placement=None):
    """Select elements along some axis of a tensor according to some index tensor.

    Args:
        x: A tensor.
        axis: The dimension along which to index.
        index: A 1-d boolean tensor such that ``len(index) == x.shape[axis]``.

    Returns:
        Copy of the tensor ``x``, with the values along ``axis`` filtered such that
        ``x[..., i, ...]`` is kept iff ``index[i] == 1``.
    """
    # TODO (Yann) extend kernels to support tuple of ints for axis
    # and multiple index
    assert isinstance(x, Expression)
    assert isinstance(index, Expression)
    axis_is_int = isinstance(axis, int)
    if not axis_is_int:
        # NOTE(review): the message mentions ">= 0" but only the type is
        # checked here; confirm whether negative axes are meant to pass.
        raise ValueError(
            "`axis` argument must be int greater or equal to 0, found "
            f"{axis} of type {type(axis)}"
        )
    resolved_placement = _materialize_placement_arg(placement)
    return SelectExpression(
        placement=resolved_placement,
        inputs=[x, index],
        axis=axis,
        vtype=x.vtype,
    )
@_docinject_placement_arg
def sliced(x, begin, end, placement=None):
    """Compute a slice of a value.

    Args:
        x: A slice-able value, e.g. a TensorType or ShapeType.
        begin: Index to start the slice at, e.g. 1 in ``x[1:5]``.
        end: Non-inclusive index to end the slice at, e.g. 5 in ``x[1:5]``.

    Returns:
        The sliced value.
    """
    assert isinstance(x, Expression)
    assert isinstance(begin, int)
    assert isinstance(end, int)
    resolved_placement = _materialize_placement_arg(placement)
    # Bounds are forwarded unchanged; the output keeps the value type of ``x``.
    slice_bounds = {"begin": begin, "end": end}
    return SliceExpression(
        placement=resolved_placement,
        inputs=[x],
        vtype=x.vtype,
        **slice_bounds,
    )
# TODO(jvmncs): better docstring
@_docinject_placement_arg
def strided_slice(x, slices, placement=None):
    """Compute a strided slice of a value.

    This version is more general than ``sliced``, as it has support for more
    complex variants of Python slices.

    Args:
        x: A value.
        slices: A list/tuple of Python ``slice`` objects.

    Returns:
        The sliced value.
    """
    assert isinstance(x, Expression)
    assert isinstance(slices, (tuple, list))
    resolved_placement = _materialize_placement_arg(placement)
    # Reject the first entry that is not a ``slice`` object.
    for candidate in slices:
        if isinstance(candidate, slice):
            continue
        raise ValueError(
            "`slices` argument must a list/tuple of slices, found "
            f"{type(candidate)}"
        )
    return StridedSliceExpression(
        placement=resolved_placement,
        inputs=[x],
        slices=slices,
        vtype=x.vtype,
    )
@_docinject_placement_arg
def transpose(x, placement=None):
    """Compute the transpose of a tensor.

    This is equivalent to ``numpy.ndarray.T``.

    Args:
        x: A tensor.

    Returns:
        The input tensor with dimensions in reverse order.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    # Output keeps the value type of the operand.
    return TransposeExpression(
        placement=resolved_placement,
        inputs=[x],
        vtype=x.vtype,
    )
def atleast_2d(x, to_column_vector=False, placement=None):
    """Build an ``AtLeast2DExpression`` for ``x``.

    Args:
        x: An ``Expression``.
        to_column_vector: Flag forwarded unchanged to the expression.
            Presumably selects a column rather than a row layout when a
            lower-rank input is expanded; confirm against the kernel.
        placement: Placement to pin the operation to. When falsy, the
            current placement is used instead.

    Returns:
        An ``AtLeast2DExpression`` carrying the same value type as ``x``.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    expression = AtLeast2DExpression(
        placement=resolved_placement,
        inputs=[x],
        to_column_vector=to_column_vector,
        vtype=x.vtype,
    )
    return expression
@_docinject_placement_arg
def reshape(x, shape, placement=None):
    """Reshape a tensor.

    Broadcasting is not allowed; the new shape must have the same total number of
    elements.

    Args:
        x: A tensor.
        shape: A list, tuple, or Shape dictating the new shape of the tensor.

    Returns:
        The reshaped tensor.
    """
    assert isinstance(x, Expression)
    placement = _materialize_placement_arg(placement)

    if isinstance(shape, (list, tuple)):
        # TODO (Yann) Currently we only have the ability to declare HostShape
        # as constant. We should add the ability to declare RepShape as constant.
        # Until then, a literal shape on a replicated placement is declared on
        # the first player of that placement.
        shape_placement = (
            placement.players[0]
            if isinstance(placement, ReplicatedPlacementExpression)
            else placement
        )
        shape_value = values.ShapeConstant(value=shape)
        shape = constant(
            shape_value,
            vtype=ty.ShapeType(),
            placement=shape_placement,
        )

    assert isinstance(shape, Expression)
    return ReshapeExpression(
        placement=placement,
        inputs=[x, shape],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def abs(x, placement=None):
    """Compute the absolute value.

    If the input is a tensor, the operation is performed elementwise.

    Args:
        x: A value.

    Returns:
        The absolute value of ``x``.
    """
    assert isinstance(x, Expression)
    resolved_placement = _materialize_placement_arg(placement)
    # Output keeps the value type of the operand.
    return AbsExpression(
        placement=resolved_placement,
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def mux(selector, x, y, placement=None):
    """Multiplex two tensors according to some condition.

    This op allows for static control-flow of elements coming from two tensors.
    For boolean tensor ``s`` and arbitrary tensors ``x`` and ``y``, the operation is
    equivalent to ``s * (x - y) + y``.

    Args:
        selector: Boolean tensor representing the control-flow condition.
        x: A tensor to fill from when the selector condition is 1.
        y: A tensor to fill from when the selector condition is 0.

    Returns:
        A tensor with elements of ``x`` and ``y`` multiplexed according to the
        condition given by ``selector``.
    """
    # The selector must be a boolean tensor expression.
    assert isinstance(selector, Expression)
    selector_vtype = selector.vtype
    assert isinstance(selector_vtype, ty.TensorType)
    assert selector_vtype.dtype.is_boolean, selector_vtype.dtype

    # Both branches must be fixed-point tensor expressions.
    assert isinstance(x, Expression)
    assert isinstance(x.vtype, ty.TensorType), x.vtype
    assert x.vtype.dtype.is_fixedpoint, x.vtype.dtype
    assert isinstance(y, Expression)
    assert isinstance(y.vtype, ty.TensorType), y.vtype
    assert y.vtype.dtype.is_fixedpoint, y.vtype.dtype

    # Only replicated placements are accepted for this op.
    resolved_placement = _materialize_placement_arg(placement)
    assert isinstance(resolved_placement, ReplicatedPlacementExpression)

    result_vtype = _assimilate_arg_vtypes(x.vtype, y.vtype, "mux")
    return MuxExpression(
        placement=resolved_placement,
        inputs=[selector, x, y],
        vtype=result_vtype,
    )
@_docinject_placement_arg
def cast(x, dtype, placement=None):
    """Cast a tensor to a new dtype.

    Args:
        x: A tensor.
        dtype: A DType, or a numpy dtype that has an entry in
            ``_NUMPY_DTYPES_MAP``.

    Returns:
        The tensor ``x`` converted to ``dtype``.
    """
    assert isinstance(x, Expression)
    placement = _materialize_placement_arg(placement)

    source_vtype = x.vtype
    if not isinstance(source_vtype, ty.TensorType):
        raise ValueError(
            f"Argument to `cast` operation must be tensor, found {source_vtype}."
        )

    # Check dtype args are well-defined
    if source_vtype.dtype is None:
        raise ValueError(
            "Argument to `cast` function must have well-defined dtype; "
            "found value with dtype=None."
        )
    if dtype is None:
        raise ValueError(
            "Invalid `dtype` argument to `cast` function: cannot cast to dtype=None."
        )

    # Ensure value can be cast by compiler/executor into the well-defined dtype arg
    if isinstance(dtype, dtypes.DType):
        target_dtype = dtype
    else:
        if dtype not in _NUMPY_DTYPES_MAP:
            raise ValueError(
                "Unsupported dtype arg in `cast` function: expected argument "
                f"of type DType, found type {type(dtype)}."
            )
        target_dtype = _NUMPY_DTYPES_MAP[dtype]

    if source_vtype.dtype == target_dtype:
        # This is a no-op: hand back the input expression untouched.
        return x

    return CastExpression(
        placement=placement,
        inputs=[x],
        vtype=ty.TensorType(target_dtype),
    )
@_docinject_placement_arg
def load(key, query="", dtype=None, vtype=None, placement=None):
    """Load a value from placement storage.

    The underlying key-value store implementation is runtime-specific. Generally,
    assume that each HostPlacement can be using a different storage implementation
    in its Moose worker/executor.

    Args:
        key: A string or Moose String corresponding to the value that should be
            loaded.
        query: An optional query string/String to provide to executor storage.
            Most common storage implementations ignore this.
        dtype: If value should be loaded as a tensor, the DType to coerce the
            tensor to. If None, inferred from the value's numpy dtype.
        vtype: The Moose type to coerce the loaded value into. If None, will be
            traced as :class:`~pymoose.computation.types.UnknownType` and the
            compiler will attempt to fill it in during its initial Typing pass.

    Returns:
        The loaded value, as provided by the worker backing the HostPlacement.
    """
    placement = _materialize_placement_arg(placement)
    vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)

    # Plain Python strings are lifted to string constants on the placement;
    # anything else must already be an expression of a compatible type.
    if isinstance(key, str):
        key = constant(key, placement=placement, vtype=ty.StringType())
    else:
        if isinstance(key, Argument) and key.vtype not in [ty.StringType(), None]:
            raise ValueError(
                f"Function 'edsl.load' encountered `key` argument of vtype {key.vtype}; "
                "expected `StringType`."
            )
        if not isinstance(key, Expression):
            raise ValueError(
                f"Function 'edsl.load' encountered `key` argument of type {type(key)}; "
                "expected one of: string, ConstantExpression, or Argument."
            )

    # Same treatment for the query.
    if isinstance(query, str):
        query = constant(query, placement=placement, vtype=ty.StringType())
    else:
        if isinstance(query, Argument) and query.vtype not in [ty.StringType(), None]:
            raise ValueError(
                f"Function 'edsl.load' encountered `query` argument of "
                f"vtype {query.vtype}; expected 'StringType'."
            )
        if not isinstance(query, Expression):
            raise ValueError(
                f"Function 'edsl.load' encountered `query` argument of type {type(query)}; "
                "expected one of: string, ConstantExpression, or Argument."
            )

    return LoadExpression(
        placement=placement,
        inputs=[key, query],
        vtype=vtype,
    )
@_docinject_placement_arg
def save(key, value, placement=None):
    """Save a key-value pair to placement storage.

    The underlying key-value store implementation is runtime-specific. Generally,
    assume that each HostPlacement can be using a different storage implementation
    in its Moose worker/executor.

    Args:
        key: A string or Moose String.
        value: A Moose Value.

    Returns:
        A Moose Value of type Unit.
    """
    assert isinstance(value, Expression)
    placement = _materialize_placement_arg(placement)

    # A plain Python string is lifted to a string constant on the placement;
    # anything else must already be an expression of a compatible type.
    if isinstance(key, str):
        key = constant(key, placement=placement, vtype=ty.StringType())
    else:
        if isinstance(key, Argument) and key.vtype not in [ty.StringType(), None]:
            raise ValueError(
                f"Function 'edsl.save' encountered `key` argument of type {key.vtype}; "
                "expected 'StringType'."
            )
        if not isinstance(key, Expression):
            raise ValueError(
                f"Function 'edsl.save' encountered `key` argument of type {type(key)}; "
                "expected one of: string, ConstantExpression, or ArgumentExpression."
            )

    # The save expression itself is created with no value type.
    return SaveExpression(
        placement=placement,
        inputs=[key, value],
        vtype=None,
    )
@_docinject_placement_arg
def output(tag, value, placement=None):
    """Tag an output of a computation.

    This op is similar to ``identity``, but additionally tags its value. It can be
    used to pin an output to a particular placement without writing it to placement
    storage. It's also useful for maintaining the ordering of computation outputs in
    PyMoose, since Moose compilation generally doesn't preserve output order.

    Args:
        tag: A tag to associate with the output, useful for reconstructing the
            original order of outputs.
        value: The Moose value to output.

    Returns:
        The tagged value.
    """
    assert isinstance(value, Expression)
    assert isinstance(tag, str)
    resolved_placement = _materialize_placement_arg(placement)
    # The tagged output keeps the value type of the wrapped value.
    return OutputExpression(
        placement=resolved_placement,
        inputs=[value],
        vtype=value.vtype,
        tag=tag,
    )
def computation(func=None, role_map=None):
    """Annotates a Python function as a Moose computation.

    Usable both as a bare decorator and as a decorator factory: when called
    without ``func``, a partial of this function with ``role_map`` bound is
    returned so that it can subsequently be applied to the function.

    Args:
        func: A Callable.
        role_map: A map of abstract placements to identities in the current
            runtime context.

    Returns:
        An abstract Moose computation that can be invoked in a runtime context.
    """
    if func is not None:
        return AbstractComputation(func, role_map)
    # No function yet: defer until the decorator is applied.
    return ft.partial(computation, role_map=role_map)
class AbstractComputation:
    """A Python callable paired with an optional role map.

    Calling an instance binds the supplied arguments to the parameter names
    of the wrapped function and hands them to the current runtime for
    evaluation.
    """

    def __init__(self, func, role_map):
        """Store ``func`` and ``role_map`` after type-checking them.

        Args:
            func: The callable to wrap.
            role_map: ``None`` or a dict of placement names to placement names.

        Raises:
            TypeError: If ``func`` is not callable, or ``role_map`` is neither
                ``None`` nor a dict.
        """
        if not callable(func):
            raise TypeError(
                f"Argument `func` should be a callable, but found {type(func)}."
            )
        self.func = func

        role_map_ok = role_map is None or isinstance(role_map, dict)
        if not role_map_ok:
            raise TypeError(
                "Argument `role_map` should be map of placement names to placement "
                f"names, found {type(role_map)}."
            )
        self.role_map = role_map

    def __call__(self, *args, **kwargs):
        """Bind the call arguments by name and evaluate on the current runtime.

        Every parameter of the wrapped function must be supplied exactly once,
        either positionally or by keyword.

        Raises:
            ValueError: On too many positional arguments, an argument given
                twice, a missing argument, or an unknown keyword argument.
            RuntimeError: If no current runtime is set.
        """
        arg_names = list(inspect.signature(self.func).parameters)

        # add values from `args`
        if len(args) > len(arg_names):
            raise ValueError(f"Too many arguments for `{self.func.__name__}`")
        arguments = dict(zip(arg_names, args))

        # add values from `kwargs`
        for kw_name, kw_val in kwargs.items():
            if kw_name in arguments:
                raise ValueError(
                    f"Argument `{kw_name}` given more than once to "
                    f"`{self.func.__name__}`"
                )
            arguments[kw_name] = kw_val

        # check that all arguments were given
        for expected in arg_names:
            if expected in arguments:
                continue
            raise ValueError(
                f"Missing argument `{expected}` in call to `{self.func.__name__}`"
            )

        # check that no extra arguments were given
        # NOTE we could potentially leave out this check
        for supplied in arguments:
            if supplied in arg_names:
                continue
            raise ValueError(
                f"Argument `{supplied}` is not used by `{self.func.__name__}`"
            )

        runtime = get_current_runtime()
        if not runtime:
            raise RuntimeError("No default runtime found")
        return runtime.evaluate_computation(self, arguments)

    def with_role_map(self, role_map):
        """Return a new instance of the same class wrapping the same function
        with ``role_map`` in place of the current one."""
        cls = self.__class__
        return cls(self.func, role_map)
def _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name):
    """Return ``lhs_vtype`` if both vtypes carry equal ``dtype`` attributes.

    Args:
        lhs_vtype: Value type of the left operand; must expose ``dtype``.
        rhs_vtype: Value type of the right operand; must expose ``dtype``.
        fn_name: Name of the calling op, used in the error message.

    Raises:
        ValueError: If the two dtypes differ.
    """
    lhs_dtype = lhs_vtype.dtype
    rhs_dtype = rhs_vtype.dtype
    if lhs_dtype != rhs_dtype:
        raise ValueError(
            f"Function `{fn_name}` expected arguments of similar dtype: "
            f"found mismatched dtypes `{lhs_dtype}` and `{rhs_dtype}`."
        )
    return lhs_vtype


def _assimilate_arg_vtypes(lhs_vtype, rhs_vtype, fn_name):
    """Return ``lhs_vtype`` if the two operand vtypes are compatible.

    Two tensor types are compared by dtype only; any other pair must compare
    equal as a whole.

    Raises:
        ValueError: If the vtypes (or tensor dtypes) do not match.
    """
    if isinstance(lhs_vtype, ty.TensorType) and isinstance(rhs_vtype, ty.TensorType):
        return _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name)
    if lhs_vtype != rhs_vtype:
        raise ValueError(
            f"Function `{fn_name}` expected arguments of similar type: "
            f"found mismatched types `{lhs_vtype}` and `{rhs_vtype}`."
        )
    return lhs_vtype


def _check_tensor_type_arg_consistency(dtype, vtype):
    """Raise ``ValueError`` if ``vtype`` is a tensor type whose dtype differs
    from ``dtype``; non-tensor vtypes are not checked."""
    if isinstance(vtype, ty.TensorType) and vtype.dtype != dtype:
        raise ValueError(
            f"Inconsistent type information for tensor: dtype {dtype} is "
            f"inconsistent with tensor type {vtype}."
        )


def _materialize_placement_arg(plc):
    """Resolve an optional placement argument to a ``PlacementExpression``.

    A falsy ``plc`` is replaced by the result of ``get_current_placement()``.

    Raises:
        TypeError: If the resolved value is not a ``PlacementExpression``.
    """
    plc = plc or get_current_placement()
    if not isinstance(plc, PlacementExpression):
        raise TypeError(f"Expected value of type Placement, found {type(plc)}.")
    return plc


def _maybe_lift_dtype_to_tensor_vtype(dtype, vtype):
    """Combine optional ``dtype``/``vtype`` arguments into a single vtype.

    Returns:
        ``None`` when neither is given, ``TensorType(dtype)`` when only
        ``dtype`` is given, and ``vtype`` otherwise (after checking it against
        ``dtype`` when both are given).
    """
    if vtype is None:
        return None if dtype is None else ty.TensorType(dtype)
    if dtype is not None:
        _check_tensor_type_arg_consistency(dtype, vtype)
    return vtype


def _interpret_numeric_value(value, vtype, fallback_vtype):
    """Wrap a Python number in the constant class matching its vtype.

    Args:
        value: An ``int`` or ``float``.
        vtype: The requested value type, or ``None`` to use ``fallback_vtype``.
        fallback_vtype: Value type used when ``vtype`` is ``None``.

    Returns:
        A ``(constant, vtype)`` tuple, where ``vtype`` is the type actually
        used for the interpretation.

    Raises:
        TypeError: If the vtype is a tensor type whose dtype is neither float
            nor integer, or is not one of the tensor, float, or int types.
    """
    assert isinstance(value, (int, float))
    if vtype is None:
        vtype = fallback_vtype
    if isinstance(vtype, ty.TensorType):
        dtype = vtype.dtype
        if not dtype.is_float and not dtype.is_integer:
            raise TypeError(f"Cannot interpret scalar constant as dtype {dtype}.")
        value = values.TensorConstant(np.array(value, dtype=dtype.numpy_dtype))
    elif isinstance(vtype, ty.FloatType):
        value = values.FloatConstant(value)
    elif isinstance(vtype, ty.IntType):
        value = values.IntConstant(value)
    else:
        # The ``f`` prefix was missing, so the offending type was never
        # interpolated into the message.
        raise TypeError(
            f"Cannot interpret numeric constant as non-numeric type {vtype}."
        )
    return value, vtype