import functools as ft
import inspect
import textwrap
from dataclasses import dataclass
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
from pymoose.computation import dtypes
from pymoose.computation import types as ty
from pymoose.computation import values
try: # post python 3.10
from types import EllipsisType
except ImportError:
EllipsisType = type(...)
# Stack of placements entered via ``with <placement>:`` blocks; the most recently
# entered placement is the last element.
CURRENT_PLACEMENT: List = []

# Maps both numpy scalar types and their ``np.dtype`` objects to the matching
# Moose dtype; ``constant`` looks ndarray dtypes up here.
_NUMPY_DTYPES_MAP = {
    key: moose_dtype
    for numpy_type, moose_dtype in (
        (np.uint32, dtypes.uint32),
        (np.uint64, dtypes.uint64),
        (np.int32, dtypes.int32),
        (np.int64, dtypes.int64),
        (np.float32, dtypes.float32),
        (np.float64, dtypes.float64),
        (np.bool_, dtypes.bool_),
    )
    for key in (numpy_type, np.dtype(numpy_type))
}
# Module-level slot holding the runtime registered via set_current_runtime.
_CURRENT_RUNTIME = None


def get_current_runtime():
    """Return the globally registered runtime, or ``None`` if none was set."""
    # Reading a module global needs no ``global`` declaration.
    return _CURRENT_RUNTIME


def set_current_runtime(runtime):
    """Register ``runtime`` as the global runtime, replacing any previous one."""
    global _CURRENT_RUNTIME
    _CURRENT_RUNTIME = runtime
@dataclass
class PlacementExpression:
    """Base class for placement expressions, usable as a context manager.

    Entering the context pushes this placement onto the module-level
    ``CURRENT_PLACEMENT`` stack; exiting pops it again.
    """

    name: str

    def __enter__(self):
        CURRENT_PLACEMENT.append(self)
        # Return the placement so ``with host_placement("a") as a:`` binds it
        # (previously the target was bound to None).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always pop, even when the body raised; exceptions are not suppressed.
        CURRENT_PLACEMENT.pop()
@dataclass
class HostPlacementExpression(PlacementExpression):
    """Placement expression consisting of a single ``name``."""

    def __hash__(self):
        # Explicit hash so @dataclass (which generates __eq__) keeps instances
        # hashable; identity is the placement name.
        return hash(self.name)
@dataclass
class MirroredPlacementExpression(PlacementExpression):
    """Named placement composed of a list of member ``players``."""

    players: List[PlacementExpression]

    def __hash__(self):
        # Hash on the name only; ``players`` (a list) is not hashable.
        return hash(self.name)
@dataclass
class ReplicatedPlacementExpression(PlacementExpression):
    """Named placement composed of a list of member ``players``."""

    players: List[PlacementExpression]

    def __hash__(self):
        # Hash on the name only; ``players`` (a list) is not hashable.
        return hash(self.name)
def host_placement(name):
    """Build a :class:`HostPlacementExpression` called ``name``."""
    placement = HostPlacementExpression(name=name)
    return placement
def mirrored_placement(name, players):
    """Build a :class:`MirroredPlacementExpression` called ``name`` over ``players``."""
    placement = MirroredPlacementExpression(name=name, players=players)
    return placement
def replicated_placement(name, players):
    """Build a :class:`ReplicatedPlacementExpression` called ``name`` over ``players``."""
    placement = ReplicatedPlacementExpression(name=name, players=players)
    return placement
def get_current_placement():
    """Return the innermost placement entered via a ``with <placement>:`` block.

    Returns:
        The last placement pushed onto ``CURRENT_PLACEMENT``.

    Raises:
        IndexError: If no placement context is active.
    """
    if not CURRENT_PLACEMENT:
        # Keep IndexError for backward compatibility, but explain the cause
        # instead of surfacing a bare "pop from empty list"-style error.
        raise IndexError(
            "No placement is currently active; pass an explicit `placement` "
            "argument or enter a placement context with `with <placement>:`."
        )
    return CURRENT_PLACEMENT[-1]
@dataclass(init=False)
class Argument:
    """A type annotation for Arguments to Moose computations.

    Annotating the parameters of a Moose computation with ``Argument`` supplies the
    extra information the compiler needs to identify their graph Values, and later
    to materialize those values at runtime.

    Args:
        placement: A placement to pin this Argument to. The corresponding InputOp for
            the argument will be pinned to this placement.
        dtype: If the Value is a Tensor, specify its ``dtype``. Using this argument is
            equivalent to specifying ``vtype=TensorType(dtype=dtype)``.
        vtype: The Moose Value type for the Argument. This type information is used
            during compilation to check correctness of the computation and to lower the
            computation graph into one that is runtime-ready. The type information of
            arguments is also checked at runtime.
    """

    placement: PlacementExpression
    dtype: Optional[dtypes.DType] = None
    vtype: Optional[ty.ValueType] = None

    def __init__(self, placement, dtype=None, vtype=None):
        # ``dtype`` is stored as given; ``vtype`` is resolved from both inputs.
        self.placement, self.dtype = placement, dtype
        self.vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)
@dataclass
class Expression:
    """Base class for nodes of the eDSL computation graph.

    Subclasses add op-specific fields. Python operators are overloaded as sugar
    for the module-level ops (``+`` -> :func:`add`, ``-`` -> :func:`sub`,
    ``*`` -> :func:`mul`, ``/`` -> :func:`div`, ``@`` -> :func:`dot`,
    ``<``/``>`` -> :func:`less`/:func:`greater`), and indexing builds slice ops.

    Attributes:
        placement: Placement this node is pinned to.
        inputs: Upstream expressions consumed by this node.
        vtype: Moose value type of this node's output, if known.
    """

    placement: PlacementExpression
    inputs: List["Expression"]
    vtype: Optional[ty.ValueType]

    def __hash__(self):
        # Hash by object identity. Declaring __hash__ explicitly also stops
        # @dataclass (which generates __eq__) from setting it to None.
        return id(self)

    # slicing sugar
    def __getitem__(self, slice_spec):
        """Slice a tensor-like or shape expression.

        Tensor/AES-tensor vtypes accept a slice, ``...``, or a list/tuple of those
        and lower to ``strided_slice``. Shape vtypes accept ``start:stop`` (no step)
        or a ``(begin, end)`` pair of ints and lower to ``sliced``.

        Raises:
            ValueError: If ``slice_spec`` is not supported for this vtype.
        """
        # TODO explicitly construe placement from
        # global placement context and/or self.placement?
        assert isinstance(self.vtype, (ty.TensorType, ty.ShapeType, ty.AesTensorType))
        assert isinstance(slice_spec, (slice, EllipsisType, list, tuple))
        if isinstance(self.vtype, (ty.TensorType, ty.AesTensorType)):
            # turn single entry to a list of entries
            if isinstance(slice_spec, (slice, EllipsisType)):
                slice_spec = (slice_spec,)
            assert isinstance(slice_spec, (list, tuple))
            slice_rewrite = []
            for cur_slice in slice_spec:
                assert isinstance(cur_slice, (slice, EllipsisType))
                if isinstance(cur_slice, EllipsisType):
                    # NOTE(review): `...` becomes one full slice, i.e. it covers a
                    # single dimension rather than all remaining ones.
                    slice_rewrite.append(slice(None, None, None))
                elif isinstance(cur_slice, slice):
                    slice_rewrite.append(cur_slice)
                else:
                    raise ValueError(
                        "Indexing with other types different than Ellipsis and slice "
                        "is not yet supported."
                    )
            return strided_slice(self, slices=slice_rewrite)
        elif isinstance(self.vtype, ty.ShapeType):
            shape_slice_error = (
                "Indexing ShapeType requires a simple slice, including only "
                "`start` & `stop` slice values."
            )
            if isinstance(slice_spec, (tuple, list)):
                # Exactly (begin, end); shorter specs used to fail with a
                # cryptic unpacking error.
                if len(slice_spec) != 2:
                    raise ValueError(shape_slice_error)
                begin, end = slice_spec
                assert isinstance(begin, int) and isinstance(end, int)
            elif isinstance(slice_spec, slice):
                if slice_spec.step is not None:
                    raise ValueError(shape_slice_error)
                begin, end = slice_spec.start, slice_spec.stop
            else:
                # e.g. ``shape[...]``: previously left begin/end unbound and
                # crashed with UnboundLocalError.
                raise ValueError(shape_slice_error)
            return sliced(self, begin, end)
        else:
            raise IndexError(f"Expression of vtype {self.vtype} is not slice-able.")

    # arithmetic sugar
    def __neg__(self):
        """Negate by multiplying with a ``-1`` constant of the same vtype.

        Raises:
            TypeError: If the value is not arithmetic, or is a tensor of an
                unsigned dtype.
        """
        _check_arithmetickable(self, "negate")
        if isinstance(self.vtype, ty.TensorType):
            if not self.vtype.dtype.is_signed:
                raise TypeError(
                    f"Cannot negate Tensor of unsigned DType {self.vtype.dtype}."
                )
        negative_one = constant(-1, vtype=self.vtype)
        return self.__rmul__(negative_one)

    def __abs__(self):
        """Absolute value; tensors of unsigned dtype are returned unchanged."""
        _check_arithmetickable(self, "abs")
        if isinstance(self.vtype, ty.TensorType):
            if not self.vtype.dtype.is_signed:
                return self
        # NOTE(review): `abs` must resolve to a module-level `abs` op shadowing
        # the builtin; the builtin would re-enter __abs__ and recurse forever.
        # Confirm that op is defined elsewhere in this module.
        return abs(self)

    # Binary operators delegate to the module-level ops after checking both
    # operands. The in-place variants build a new expression (no mutation).
    def __add__(self, other):
        return _binary_dunder_method(self, other, add, "add")

    def __radd__(self, other):
        return _binary_dunder_method(other, self, add, "add")

    def __iadd__(self, other):
        return _binary_dunder_method(self, other, add, "add")

    def __sub__(self, other):
        return _binary_dunder_method(self, other, sub, "subtract")

    def __rsub__(self, other):
        return _binary_dunder_method(other, self, sub, "subtract")

    def __isub__(self, other):
        return _binary_dunder_method(self, other, sub, "subtract")

    def __mul__(self, other):
        return _binary_dunder_method(self, other, mul, "multiply")

    def __rmul__(self, other):
        return _binary_dunder_method(other, self, mul, "multiply")

    def __imul__(self, other):
        return _binary_dunder_method(self, other, mul, "multiply")

    def __truediv__(self, other):
        return _binary_dunder_method(self, other, div, "divide")

    def __rtruediv__(self, other):
        return _binary_dunder_method(other, self, div, "divide")

    def __itruediv__(self, other):
        return _binary_dunder_method(self, other, div, "divide")

    def __matmul__(self, other):
        return _binary_dunder_method(self, other, dot, "dot-product")

    def __rmatmul__(self, other):
        return _binary_dunder_method(other, self, dot, "dot-product")

    def __imatmul__(self, other):
        return _binary_dunder_method(self, other, dot, "dot-product")

    def __gt__(self, other):
        return _binary_dunder_method(self, other, greater, "greater-than")

    def __lt__(self, other):
        return _binary_dunder_method(self, other, less, "less-than")
def _binary_dunder_method(x, y, fn, fn_desc):
    """Validate both operands of an operator overload, then return ``fn(x, y)``.

    ``fn_desc`` only appears in the TypeError raised for a non-arithmetic operand.
    """
    for operand in (x, y):
        _check_arithmetickable(operand, fn_desc)
    return fn(x, y)
def _check_arithmetickable(expr, fn_name):
    """Raise TypeError unless ``expr`` has a tensor, float or int vtype.

    Args:
        expr: Operand of an overloaded operator.
        fn_name: Human-readable operation name used in the error message.

    Raises:
        TypeError: If ``expr`` has no ``vtype`` (e.g. a plain Python number), or its
            vtype does not support arithmetic.
    """
    if not hasattr(expr, "vtype"):
        # Plain Python/numpy operands (e.g. ``2 * x``) previously crashed with an
        # AttributeError; report a clear TypeError instead.
        raise TypeError(
            f"Value of type {type(expr).__name__} is not {fn_name}-able; "
            "wrap it with `constant` first."
        )
    if not isinstance(expr.vtype, (ty.TensorType, ty.FloatType, ty.IntType)):
        raise TypeError(f"Value of vtype {expr.vtype} is not {fn_name}-able.")
@dataclass
class AddNExpression(Expression):
    """Node built by ``add_n``: elementwise sum of all ``inputs``."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class IdentityExpression(Expression):
    """Node built by ``identity``; forwards its single input unchanged."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ArgumentExpression(Expression):
    """Node referring to a computation argument by name."""

    # Name of the referenced argument.
    arg_name: str

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ConcatenateExpression(Expression):
    """Node built by ``concatenate``; joins ``inputs`` along ``axis``."""

    axis: Optional[int]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class MaximumExpression(Expression):
    """Node built by ``maximum``: elementwise max of all ``inputs``."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class DecryptExpression(Expression):
    """Node built by ``decrypt``; its inputs are ``[key, ciphertext]``."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ConstantExpression(Expression):
    """Node built by ``constant``; holds an embedded value and has no inputs."""

    # NOTE(review): `constant` also stores values.TensorConstant /
    # StringConstant / ShapeConstant objects here, so this annotation is
    # narrower than actual usage.
    value: Union[int, float]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class BinaryAndExpression(Expression):
    """Binary node tagged with an ``op_name``.

    NOTE(review): no builder visible in this module constructs it;
    ``logical_and`` emits a BinaryOpExpression with ``op_name="and"``.
    """

    op_name: str

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class BinaryOpExpression(Expression):
    """Two-input node whose operation is named by ``op_name``.

    Used for e.g. ``"add"``, ``"sub"``, ``"mul"``, ``"dot"``, ``"div"``,
    ``"less"``, ``"greater"``, ``"and"`` and ``"or"``.
    """

    op_name: str

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ExpandDimsExpression(Expression):
    """Node built by ``expand_dims``; ``axis`` lists the new singleton dims."""

    # ``expand_dims`` wraps a single int into a list, so this is a sequence of
    # any length (``Tuple[int]`` would denote a 1-tuple).
    axis: Union[List[int], Tuple[int, ...]]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SqueezeExpression(Expression):
    """Node built by ``squeeze``; ``axis=None`` drops all singleton dims."""

    # Variable-length tuple of axes (``Tuple[int]`` would denote a 1-tuple).
    axis: Optional[Union[int, Tuple[int, ...]]]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class OnesExpression(Expression):
    """Node built by ``ones``; its single input is the shape expression."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ZerosExpression(Expression):
    """Node built by ``zeros``; its single input is the shape expression."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SquareExpression(Expression):
    """Expression node for a square op.

    NOTE(review): ``square`` in this module emits ``mul(x, x)`` instead, so this
    class is not constructed there.
    """

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SumExpression(Expression):
    """Node built by ``sum``; ``axis=None`` reduces over the whole tensor."""

    # Variable-length tuple of axes (``Tuple[int]`` would denote a 1-tuple).
    axis: Optional[Union[int, Tuple[int, ...]]]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class MeanExpression(Expression):
    """Node built by ``mean``; ``axis=None`` reduces over the whole tensor."""

    # Variable-length tuple of axes (``Tuple[int]`` would denote a 1-tuple).
    axis: Optional[Union[int, Tuple[int, ...]]]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ExpExpression(Expression):
    """Node built by ``exp`` (elementwise exponential)."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SigmoidExpression(Expression):
    """Node built by ``sigmoid``."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SoftmaxExpression(Expression):
    """Node built by ``softmax`` along ``axis``, bounded by ``upmost_index``."""

    # Variable-length tuple of axes (``Tuple[int]`` would denote a 1-tuple).
    axis: Optional[Union[int, Tuple[int, ...]]]
    upmost_index: int

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ReluExpression(Expression):
    """Node built by ``relu``."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ArgmaxExpression(Expression):
    """Node built by ``argmax`` along ``axis``, bounded by ``upmost_index``."""

    # Variable-length tuple of axes (``Tuple[int]`` would denote a 1-tuple).
    axis: Optional[Union[int, Tuple[int, ...]]]
    upmost_index: int

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class LogExpression(Expression):
    """Node built by ``log`` (elementwise natural logarithm)."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class Log2Expression(Expression):
    """Node built by ``log2`` (elementwise base-2 logarithm)."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SqrtExpression(Expression):
    """Node built by ``sqrt``."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class TransposeExpression(Expression):
    """Expression node for a transpose operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ReshapeExpression(Expression):
    """Expression node for a reshape operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class AtLeast2DExpression(Expression):
    """Expression node for an at-least-2D operation."""

    # NOTE(review): presumably selects column- vs row-vector layout when
    # promoting a 1-D input; confirm against the op's builder.
    to_column_vector: bool

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class LoadExpression(Expression):
    """Expression node for a load operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class InverseExpression(Expression):
    """Node built by ``inverse`` (matrix inverse)."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class AbsExpression(Expression):
    """Expression node for an absolute-value operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class CastExpression(Expression):
    """Expression node for a dtype cast."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SaveExpression(Expression):
    """Expression node for a save operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class ShapeExpression(Expression):
    """Expression node for a shape-of operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class IndexAxisExpression(Expression):
    """Expression node selecting position ``index`` along dimension ``axis``."""

    axis: int
    index: int

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SelectExpression(Expression):
    """Expression node for a select operation along ``axis``."""

    axis: int

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class SliceExpression(Expression):
    """Expression node for a simple ``[begin:end]`` slice."""

    # NOTE(review): ``Expression.__getitem__`` forwards slice start/stop to
    # ``sliced`` unchanged, so these may be None unless ``sliced`` normalizes
    # them; confirm.
    begin: int
    end: int

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class StridedSliceExpression(Expression):
    """Expression node for a strided slice with one ``slice`` per dimension."""

    # One slice per indexed dimension, so variable length (``Tuple[slice]``
    # would denote a 1-tuple). NOTE(review): ``Expression.__getitem__`` builds
    # this as a list before passing it to ``strided_slice``.
    slices: Optional[Tuple[slice, ...]]

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class LessExpression(Expression):
    """Expression node for a less-than comparison.

    NOTE(review): ``less`` in this module emits a BinaryOpExpression with
    ``op_name="less"`` instead of this class.
    """

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class GreaterExpression(Expression):
    """Expression node for a greater-than comparison.

    NOTE(review): ``greater`` in this module emits a BinaryOpExpression with
    ``op_name="greater"`` instead of this class.
    """

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class BitwiseAndExpression(Expression):
    """Expression node for a bitwise AND operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class BitwiseOrExpression(Expression):
    """Expression node for a bitwise OR operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class MuxExpression(Expression):
    """Expression node for a mux (selection) operation."""

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
@dataclass
class OutputExpression(Expression):
    """Expression node for a computation output, labelled with ``tag``."""

    tag: str

    def __hash__(self):
        # Identity hash; @dataclass would otherwise set __hash__ to None.
        return id(self)
def _docinject_placement_arg(f):
"""Appends a ``placement`` entry to the Args section of a docstring.
Docstring must be Google style. ``placement`` should be the last keyword
argument, otherwise you should document it manually.
Args:
f: Callable whose docstring you want to amend.
Returns:
The function f with a modified docstring
"""
docstring = f.__doc__
lines = docstring.split("\n")
in_args = False
for i, line in enumerate(lines):
if i == len(lines) - 1:
break
nextline = lines[i + 1]
dedented_line = textwrap.dedent(line)
line_indentation = len(line) - len(dedented_line)
if dedented_line.startswith("Args:"):
in_args = True
arg_indentation = line_indentation + 4
continue
if in_args and nextline == "":
placement_arg_doc = textwrap.indent(
"placement: An optional :class:`~pymoose.computation.placements.Placement` to pin this operation to.",
" " * arg_indentation,
)
lines.insert(i + 1, placement_arg_doc)
in_args = False
new_docstring = "\n".join(lines)
f.__doc__ = new_docstring
return f
@_docinject_placement_arg
def add_n(arrays, placement=None):
    """Elementwise addition of a collection of tensors.

    Args:
        arrays: Tuple or list of tensors w/ identical shape and dtype.

    Returns:
        A tensor containing the elementwise sum of all input tensors.

    Raises:
        ValueError: If ``arrays`` is not a non-empty tuple/list of tensors that
            share a single vtype.
    """
    placement = _materialize_placement_arg(placement)
    if not isinstance(arrays, (tuple, list)):
        raise ValueError(
            "Inputs to `add_n` must be array-like, found argument "
            f"of type {type(arrays)}."
        )
    if not arrays:
        # Previously an IndexError from ``arrays[0]``.
        raise ValueError("Inputs to `add_n` must contain at least one tensor.")
    input_vtype = arrays[0].vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(f"Inputs must have vtype TensorType, found {input_vtype}.")
    expected_dtype = input_vtype.dtype
    for array in arrays:
        if not isinstance(array.vtype, ty.TensorType):
            raise ValueError(
                f"Inputs must have vtype TensorType, found {array.vtype}."
            )
        # Expressions carry their dtype on the vtype; the old message read a
        # nonexistent ``array.dtype`` attribute.
        if array.vtype.dtype != expected_dtype:
            raise ValueError(
                f"Values passed to add_n must be same dtype: found "
                f"{array.vtype.dtype} and {expected_dtype} in value of `arrays` "
                "argument."
            )
        if array.vtype != input_vtype:
            raise ValueError(
                f"Values passed to add_n must share a vtype: found {array.vtype} "
                f"and {input_vtype} in value of `arrays` argument."
            )
    return AddNExpression(placement=placement, inputs=arrays, vtype=input_vtype)
@_docinject_placement_arg
def identity(x, placement=None):
    """Pass a value through unchanged, ``f(x) = x``.

    Only the value is guaranteed to be preserved; the resulting expression may be
    pinned to a different placement than ``x``.

    Args:
        x: Value to forward.

    Returns:
        An expression carrying the value of ``x`` with the same vtype.
    """
    # Keyword arguments evaluate left to right, so the placement is still
    # resolved before ``x.vtype`` is read.
    return IdentityExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def concatenate(arrays, axis=0, placement=None):
    """Concatenation of a collection of tensors along a given axis/dimension.

    Note the tensors must have similar dtype and shape along all dimensions but the
    given one. The dimension must already exist in the input tensors; use
    :func:`expand_dims` on inputs to create a new dimension for concatenating along.

    Args:
        arrays: Tuple or List of tensors to be concatenated.
        axis: Optional integer representing the dimension to concatenate along.
            Dimension must already exist for all input tensors.

    Returns:
        The concatenated tensor, with the same vtype as the inputs.

    Raises:
        ValueError: If ``arrays`` is not a non-empty tuple/list of tensors that
            share a single vtype.
    """
    placement = _materialize_placement_arg(placement)
    if not isinstance(arrays, (tuple, list)):
        raise ValueError(
            "Inputs to `concatenate` must be array-like, found argument "
            f"of type {type(arrays)}."
        )
    if not arrays:
        # Previously an IndexError from ``arrays[0]``.
        raise ValueError("Inputs to `concatenate` must contain at least one tensor.")
    input_vtype = arrays[0].vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(f"Inputs must have vtype TensorType, found {input_vtype}.")
    expected_dtype = input_vtype.dtype
    for array in arrays:
        if not isinstance(array.vtype, ty.TensorType):
            raise ValueError(
                f"Inputs must have vtype TensorType, found {array.vtype}."
            )
        # Expressions carry their dtype on the vtype; the old message read a
        # nonexistent ``array.dtype`` attribute.
        if array.vtype.dtype != expected_dtype:
            raise ValueError(
                f"Values passed to concatenate must be same dtype: found "
                f"{array.vtype.dtype} and {expected_dtype} in value of `arrays` "
                "argument."
            )
        if array.vtype != input_vtype:
            raise ValueError(
                f"Values passed to concatenate must share a vtype: found "
                f"{array.vtype} and {input_vtype} in value of `arrays` argument."
            )
    return ConcatenateExpression(
        placement=placement, inputs=arrays, axis=axis, vtype=input_vtype
    )
@_docinject_placement_arg
def maximum(arrays, placement=None):
    """Elementwise maximum of a collection of input tensors.

    Args:
        arrays: Tuple or list of tensors w/ identical shape and dtype.

    Returns:
        A tensor containing the elementwise max of all input tensors.

    Raises:
        ValueError: If ``arrays`` is not a non-empty tuple/list of tensors that
            share a single vtype.
    """
    placement = _materialize_placement_arg(placement)
    if not isinstance(arrays, (tuple, list)):
        # The message previously named `concatenate` (copy-paste error).
        raise ValueError(
            "Inputs to `maximum` must be array-like, found argument "
            f"of type {type(arrays)}."
        )
    if not arrays:
        # Previously an IndexError from ``arrays[0]``.
        raise ValueError("Inputs to `maximum` must contain at least one tensor.")
    input_vtype = arrays[0].vtype
    if not isinstance(input_vtype, ty.TensorType):
        raise ValueError(f"Inputs must have vtype TensorType, found {input_vtype}.")
    expected_dtype = input_vtype.dtype
    for array in arrays:
        if not isinstance(array.vtype, ty.TensorType):
            raise ValueError(
                f"Inputs must have vtype TensorType, found {array.vtype}."
            )
        # Expressions carry their dtype on the vtype; the old message read a
        # nonexistent ``array.dtype`` attribute.
        if array.vtype.dtype != expected_dtype:
            raise ValueError(
                f"Values passed to maximum must be same dtype: found "
                f"{array.vtype.dtype} and {expected_dtype} in value of `arrays` "
                "argument."
            )
        if array.vtype != input_vtype:
            raise ValueError(
                f"Values passed to maximum must share a vtype: found {array.vtype} "
                f"and {input_vtype} in value of `arrays` argument."
            )
    return MaximumExpression(placement=placement, inputs=arrays, vtype=input_vtype)
@_docinject_placement_arg
def decrypt(key, ciphertext, placement=None):
    """Decrypt an AES ciphertext tensor with a key.

    Args:
        key: An expression of vtype ``AesKeyType``.
        ciphertext: An expression of vtype ``AesTensorType``.

    Returns:
        A tensor whose dtype is the dtype of ``ciphertext``'s vtype.

    Raises:
        ValueError: If ``key`` or ``ciphertext`` has the wrong vtype.
    """
    placement = _materialize_placement_arg(placement)
    # key expr typecheck
    if not isinstance(key.vtype, ty.AesKeyType):
        # Missing ``f`` prefix previously printed a literal "{key.vtype}".
        raise ValueError(
            f"Parameter `key` expected to be of type AesKeyType, found {key.vtype}."
        )
    # ciphertext expr typecheck
    if not isinstance(ciphertext.vtype, ty.AesTensorType):
        raise ValueError(
            "Parameter `ciphertext` expected to be of type AesTensorType, "
            f"found {ciphertext.vtype}."
        )
    # decrypt converts AesTensorType(fixed(i, f)) -> TensorType(fixed(i, f))
    output_dtype = ciphertext.vtype.dtype
    output_type = ty.TensorType(output_dtype)
    return DecryptExpression(
        placement=placement,
        inputs=[key, ciphertext],
        vtype=output_type,
    )
@_docinject_placement_arg
def constant(value, dtype=None, vtype=None, placement=None):
    """Embed a constant of a particular type.

    Args:
        value: The Python value to embed as constant. Supported types include the Python
            native variants of those listed in ``pymoose.computation.types``.
            Tensors are embeddable as numpy ndarrays.
        dtype: If given, coerces ``value`` into a Moose tensor with this dtype.
            Otherwise, if ``value`` is an ndarray, the Moose dtype is inferred from the
            ndarray's numpy dtype.
        vtype: If given, coerces ``value`` into a Moose value with this vtype, i.e.
            ValueType. Otherwise, infers the vtype from ``value``'s Python type.

    Returns:
        A Moose Value of type ``vtype`` embedded into the computation graph. When an
        ndarray's inferred dtype differs from the requested one, this is a ``cast``
        of the embedded constant.

    Raises:
        NotImplementedError: If ``value`` is an ndarray of an unsupported dtype.
        TypeError: If an ndarray must be cast but the target dtype is not a DType.
        ValueError: If a ``str`` value is paired with a non-string vtype.
    """
    placement = _materialize_placement_arg(placement)
    vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)
    if isinstance(value, np.ndarray):
        moose_dtype = _NUMPY_DTYPES_MAP.get(value.dtype.type, None)
        if moose_dtype is None:
            raise NotImplementedError(
                f"Tensors of dtype `{value.dtype}` not supported as graph constants."
            )
        if vtype is not None and moose_dtype != vtype.dtype:
            # Embed with the ndarray's own dtype, then cast to the requested one.
            dtype = dtype or vtype.dtype
            if not isinstance(dtype, dtypes.DType):
                raise TypeError(
                    "`dtype` argument to `constant` must be of type DType, "
                    f"found {type(dtype)}."
                )
            implicit_const = constant(value, dtype=moose_dtype, placement=placement)
            return cast(implicit_const, dtype, placement)
        elif vtype is None:
            vtype = ty.TensorType(moose_dtype)
        value = values.TensorConstant(value=value)
    elif isinstance(value, float):
        if isinstance(vtype, ty.TensorType) and vtype.dtype.is_fixedpoint:
            # want to use implicit casting, so simply wrap as ndarray and recurse;
            # forward the placement so an explicit argument is not replaced by
            # the ambient placement context.
            return constant(np.array(value), vtype=vtype, placement=placement)
        value, vtype = _interpret_numeric_value(value, vtype, ty.FloatType())
    elif isinstance(value, int):
        if isinstance(vtype, ty.TensorType) and vtype.dtype.is_fixedpoint:
            # want to use implicit casting, so simply wrap as ndarray and recurse;
            # forward the placement so an explicit argument is not replaced by
            # the ambient placement context.
            return constant(np.array(value), vtype=vtype, placement=placement)
        value, vtype = _interpret_numeric_value(value, vtype, ty.IntType())
    elif isinstance(value, str):
        vtype = vtype or ty.StringType()
        if not isinstance(vtype, ty.StringType):
            raise ValueError(
                "Constant value of type `str` does not match "
                f"user-supplied vtype argument `{vtype}`."
            )
        value = values.StringConstant(value=value)
    # Any other value (e.g. a pre-built values.ShapeConstant) is embedded as-is.
    return ConstantExpression(placement=placement, inputs=[], value=value, vtype=vtype)
@_docinject_placement_arg
def add(lhs, rhs, placement=None):
    """Add two values, i.e. ``lhs + rhs``.

    Tensor inputs are added elementwise.

    Args:
        lhs: First addend.
        rhs: Second addend.

    Returns:
        Sum of the two inputs.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "add")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="add",
    )
@_docinject_placement_arg
def sub(lhs, rhs, placement=None):
    """Subtract two values, i.e. ``lhs - rhs``.

    Tensor inputs are subtracted elementwise.

    Args:
        lhs: Minuend.
        rhs: Subtrahend.

    Returns:
        Difference of the two inputs.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "sub")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="sub",
    )
@_docinject_placement_arg
def mul(lhs, rhs, placement=None):
    """Multiply two values, i.e. ``lhs * rhs``.

    Tensor inputs are multiplied elementwise.

    Args:
        lhs: First factor.
        rhs: Second factor.

    Returns:
        Product of the two inputs.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "mul")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="mul",
    )
@_docinject_placement_arg
def dot(lhs, rhs, placement=None):
    """Dot product of two tensors.

    Contracts the second dimension of ``lhs`` with the first dimension of ``rhs``;
    for 1- and 2-dimensional inputs this matches ``np.dot(lhs, rhs)``.

    Args:
        lhs: Left-hand tensor factor.
        rhs: Right-hand tensor factor.

    Returns:
        Dot product of the two input tensors.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "dot")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="dot",
    )
@_docinject_placement_arg
def div(lhs, rhs, placement=None):
    """Divide two values, i.e. ``lhs / rhs``.

    Tensor inputs are divided elementwise.

    Args:
        lhs: Dividend.
        rhs: Divisor.

    Returns:
        Quotient of the two inputs.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "div")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="div",
    )
@_docinject_placement_arg
def less(lhs, rhs, placement=None):
    """Evaluate the boolean less-than comparison ``lhs < rhs``.

    Tensor inputs are compared elementwise.

    Args:
        lhs: Left-hand side of comparison.
        rhs: Right-hand side of comparison.

    Returns:
        The comparison result, typed as a bool tensor.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    bool_tensor = ty.TensorType(dtype=dtypes.bool_)
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=bool_tensor,
        op_name="less",
    )
@_docinject_placement_arg
def greater(lhs, rhs, placement=None):
    """Evaluate the boolean greater-than comparison ``lhs > rhs``.

    Tensor inputs are compared elementwise.

    Args:
        lhs: Left-hand side of comparison.
        rhs: Right-hand side of comparison.

    Returns:
        The comparison result, typed as a bool tensor.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    bool_tensor = ty.TensorType(dtype=dtypes.bool_)
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=bool_tensor,
        op_name="greater",
    )
@_docinject_placement_arg
def logical_and(lhs, rhs, placement=None):
    """Evaluate the boolean AND ``lhs & rhs``.

    Tensor inputs are combined elementwise.

    Args:
        lhs: Left-hand side of operation.
        rhs: Right-hand side of operation.

    Returns:
        The logical intersection of the two inputs when treated as booleans.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "and")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="and",
    )
@_docinject_placement_arg
def logical_or(lhs, rhs, placement=None):
    """Evaluate the boolean OR ``lhs | rhs``.

    Tensor inputs are combined elementwise.

    Args:
        lhs: Left-hand side of operation.
        rhs: Right-hand side of operation.

    Returns:
        The logical union of the two inputs when treated as booleans.
    """
    for operand in (lhs, rhs):
        assert isinstance(operand, Expression)
    target = _materialize_placement_arg(placement)
    result_vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "or")
    return BinaryOpExpression(
        placement=target,
        inputs=[lhs, rhs],
        vtype=result_vtype,
        op_name="or",
    )
@_docinject_placement_arg
def inverse(x, placement=None):
    """Invert a floating-point matrix.

    Args:
        x: A 2-dimensional float tensor to invert.

    Returns:
        The matrix inverse of ``x``.

    Raises:
        ValueError: If ``x`` is not a tensor of dtype ``float32`` or ``float64``.
    """
    # The decorator documents ``placement`` like every sibling op.
    assert isinstance(x, Expression)
    placement = _materialize_placement_arg(placement)
    vtype = x.vtype
    if not isinstance(vtype, ty.TensorType):
        raise ValueError(
            "`inverse` operation only supports arguments of type TensorType."
        )
    if vtype.dtype not in [dtypes.float32, dtypes.float64]:
        raise ValueError(
            "`inverse` operation only supports arguments of dtype `float32` or "
            "`float64`."
        )
    return InverseExpression(placement=placement, inputs=[x], vtype=vtype)
@_docinject_placement_arg
def expand_dims(x, axis, placement=None):
    """Expand the rank of ``x`` by inserting new singleton dimensions.

    Args:
        x: The tensor to expand.
        axis: Index of the new dimension, or a tuple/list with the index of each
            new dimension.

    Returns:
        The expanded tensor, with the new singleton dimensions inserted.

    Raises:
        ValueError: If ``axis`` is a tuple/list containing a non-int entry.
    """
    assert isinstance(x, Expression)
    if isinstance(axis, int):
        # A single index is normalized into a one-element list.
        axis = [axis]
    elif isinstance(axis, (tuple, list)):
        offending_types = [type(ax) for ax in axis if not isinstance(ax, int)]
        if offending_types:
            raise ValueError(
                "`axis` argument must be int or list/tuple of ints, found "
                f"{offending_types[0]}"
            )
    target = _materialize_placement_arg(placement)
    return ExpandDimsExpression(
        placement=target, inputs=[x], axis=axis, vtype=x.vtype
    )
@_docinject_placement_arg
def squeeze(x, axis=None, placement=None):
    """Drop singleton dimensions from ``x``.

    Args:
        x: The tensor from which to drop singleton dimensions.
        axis: Optional index of the singleton dimension to drop. If None, every
            singleton dimension of ``x`` is dropped.

    Returns:
        The squeezed tensor.
    """
    assert isinstance(x, Expression)
    return SqueezeExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        axis=axis,
        vtype=x.vtype,
    )
@_docinject_placement_arg
def ones(shape, dtype, placement=None):
    """Embed a ones array into the Moose computation graph.

    Equivalent to ``pm.constant(np.ones(shape, dtype))`` for a given ``dtype`` and
    ``shape``.

    Args:
        shape: Shape of the ones array: either a shape expression, or a tuple/list
            of dimension sizes that is embedded as a shape constant.
        dtype: Dtype of the ones array.

    Returns:
        A tensor of all 1s with given shape and dtype.
    """
    # Tuple/list shapes must pass this check; asserting only Expression made
    # the constant-shape branch below unreachable.
    assert isinstance(shape, (Expression, list, tuple))
    placement = _materialize_placement_arg(placement)
    if isinstance(shape, (list, tuple)):
        # TODO (Yann) Currently we only have the ability to declare HostShape
        # as constant. We should add the ability to declare RepShape as constant.
        if isinstance(placement, ReplicatedPlacementExpression):
            shape_placement = placement.players[0]
        else:
            shape_placement = placement
        shape = constant(
            values.ShapeConstant(value=shape),
            vtype=ty.ShapeType(),
            placement=shape_placement,
        )
    vtype = ty.TensorType(dtype)
    return OnesExpression(placement=placement, inputs=[shape], vtype=vtype)
@_docinject_placement_arg
def zeros(shape, dtype, placement=None):
    """Embed a zeros array into the Moose computation graph.

    Equivalent to ``pm.constant(np.zeros(shape, dtype))`` for a given ``dtype`` and
    ``shape``.

    Args:
        shape: Shape of the zeros array: either a shape expression, or a tuple/list
            of dimension sizes that is embedded as a shape constant.
        dtype: Dtype of the zeros array.

    Returns:
        A tensor of all 0s with given shape and dtype.
    """
    # Tuple/list shapes must pass this check; asserting only Expression made
    # the constant-shape branch below unreachable.
    assert isinstance(shape, (Expression, list, tuple))
    placement = _materialize_placement_arg(placement)
    if isinstance(shape, (list, tuple)):
        # TODO (Yann) Currently we only have the ability to declare HostShape
        # as constant. We should add the ability to declare RepShape as constant.
        if isinstance(placement, ReplicatedPlacementExpression):
            shape_placement = placement.players[0]
        else:
            shape_placement = placement
        shape = constant(
            values.ShapeConstant(value=shape),
            vtype=ty.ShapeType(),
            placement=shape_placement,
        )
    vtype = ty.TensorType(dtype)
    return ZerosExpression(placement=placement, inputs=[shape], vtype=vtype)
@_docinject_placement_arg
def square(x, placement=None):
    """Square an input value (elementwise for tensors).

    Args:
        x: A value to square.

    Returns:
        The squared input.
    """
    assert isinstance(x, Expression)
    # Expressed as a self-multiplication rather than a dedicated square op.
    target = _materialize_placement_arg(placement)
    return mul(x, x, placement=target)
@_docinject_placement_arg
def sum(x, axis=None, placement=None):
    """Sum-reduce an input tensor.

    Sums along ``axis``, or over every element when ``axis=None``.

    Args:
        x: A tensor.
        axis: Optional dimension to sum-reduce along. If None, the whole tensor is
            reduced to a scalar tensor.

    Returns:
        The summed input, with one or all dimensions reduced out.
    """
    assert isinstance(x, Expression)
    return SumExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        axis=axis,
        vtype=x.vtype,
    )
@_docinject_placement_arg
def mean(x, axis=None, placement=None):
    """Mean-reduce an input tensor.

    Args:
        x: A tensor.
        axis: Optional dimension to mean-reduce along. If None, the whole tensor is
            reduced to a scalar tensor.

    Returns:
        The averaged input, with one or all dimensions reduced out.
    """
    assert isinstance(x, Expression)
    return MeanExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        axis=axis,
        vtype=x.vtype,
    )
@_docinject_placement_arg
def exp(x, placement=None):
    """Elementwise exponential function: :math:`e^x`.

    Args:
        x: A value; tensors are exponentiated elementwise.

    Returns:
        The exponentiated input.
    """
    assert isinstance(x, Expression)
    return ExpExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def sqrt(x, placement=None):
    r"""Compute the square-root of a value, :math:`\sqrt{x}`.

    If the input is a tensor, the operation is performed elementwise.

    Args:
        x: A value.

    Returns:
        The square-root of the input.
    """
    # Raw docstring: ``\s`` in a plain string is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    assert isinstance(x, Expression)
    placement = _materialize_placement_arg(placement)
    return SqrtExpression(placement=placement, inputs=[x], vtype=x.vtype)
@_docinject_placement_arg
def sigmoid(x, placement=None):
    r"""Apply the sigmoid function, :math:`\frac{1}{1 + e^{-x}}`.

    Args:
        x: A value; tensors are processed elementwise.

    Returns:
        The result of applying the sigmoid function to ``x``.
    """
    assert isinstance(x, Expression)
    return SigmoidExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def relu(x, placement=None):
    """Apply the rectified-linear unit (ReLU) function, ``f(x) = max(0, x)``.

    Args:
        x: A value; tensors are processed elementwise.

    Returns:
        A value corresponding to ``ReLU(x)``.
    """
    assert isinstance(x, Expression)
    return ReluExpression(
        placement=_materialize_placement_arg(placement),
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def softmax(x, axis, upmost_index, placement=None):
    r"""Softmax function.

    .. math ::

        \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Args:
        x: A tensor.
        axis: Dimension along which the divisor's sum-reduce is computed.
        upmost_index: Max index used when computing the divisor's sum-reduce;
            generally the size of the ``axis`` dimension of ``x``.

    Returns:
        The softmax of ``x`` along ``axis``.
    """
    assert isinstance(x, Expression)
    target = _materialize_placement_arg(placement)
    return SoftmaxExpression(
        placement=target,
        inputs=[x],
        vtype=x.vtype,
        axis=axis,
        upmost_index=upmost_index,
    )
@_docinject_placement_arg
def argmax(x, axis, upmost_index, placement=None):
    """Find the index of the largest element of a tensor along one dimension.

    Args:
        x: A tensor.
        axis: The dimension along which the argmax is taken.
        upmost_index: The largest index considered when computing the argmax.
            This should generally equal the size of dimension ``axis`` of ``x``.

    Returns:
        A tensor with ``axis`` reduced out, holding the argmax of ``x``.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    return ArgmaxExpression(
        placement=plc,
        inputs=[x],
        vtype=x.vtype,
        axis=axis,
        upmost_index=upmost_index,
    )
@_docinject_placement_arg
def log(x, placement=None):
    """Take the elementwise natural logarithm of a tensor.

    Args:
        x: A tensor.

    Returns:
        A tensor holding ``log(x)``.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    return LogExpression(placement=plc, inputs=[x], vtype=x.vtype)
@_docinject_placement_arg
def log2(x, placement=None):
    """Take the elementwise base-2 logarithm of a tensor.

    Args:
        x: A tensor.

    Returns:
        A tensor holding ``log_2(x)``.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    return Log2Expression(
        placement=plc,
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def shape(x, placement=None):
    """Get the shape of a tensor.

    Args:
        x: A tensor.

    Returns:
        A Moose value of type ``ShapeType`` describing the shape of ``x``.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    shape_vtype = ty.ShapeType()
    return ShapeExpression(placement=plc, inputs=[x], vtype=shape_vtype)
@_docinject_placement_arg
def index_axis(x, axis, index, placement=None):
    """Index a tensor along a given dimension.

    Args:
        x: A tensor.
        axis: The dimension along which to index; must be an int >= 0.
        index: The position to select along ``axis``; must be an int >= 0.
            A ``ValueError`` is raised if ``axis`` or ``index`` is not a
            non-negative int.

    Returns:
        A slice of ``x`` corresponding to the sub-tensor at ``index`` along ``axis``.
    """
    assert isinstance(x, Expression)
    # Each argument is validated against its own value. Checking `index` here
    # instead would let negative axes through, and a non-int `index` would
    # raise a TypeError from the comparison.
    if not isinstance(axis, int) or axis < 0:
        raise ValueError(
            "`axis` argument must be int greater or equal to 0, found "
            f"{axis} of type {type(axis)}"
        )
    if not isinstance(index, int) or index < 0:
        raise ValueError(
            "`index` argument must be int greater or equal to 0, found "
            f"{index} of type {type(index)}"
        )
    placement = _materialize_placement_arg(placement)
    return IndexAxisExpression(
        placement=placement, inputs=[x], axis=axis, index=index, vtype=x.vtype
    )
@_docinject_placement_arg
def select(x, axis, index, placement=None):
    """Filter a tensor along one axis using a boolean index tensor.

    Args:
        x: A tensor.
        axis: The dimension along which to filter.
        index: A 1-d boolean tensor such that ``len(index) == x.shape[axis]``.

    Returns:
        A copy of ``x`` in which, along ``axis``, the entry ``x[..., i, ...]`` is
        kept iff ``index[i] == 1``.
    """
    # TODO (Yann) extend kernels to support tuple of ints for axis
    # and multiple index
    assert isinstance(x, Expression)
    assert isinstance(index, Expression)
    # NOTE(review): the message mentions ">= 0", but only the int type is
    # checked here; confirm whether negative axes are meant to be rejected.
    if not isinstance(axis, int):
        raise ValueError(
            "`axis` argument must be int greater or equal to 0, found "
            f"{axis} of type {type(axis)}"
        )
    plc = _materialize_placement_arg(placement)
    return SelectExpression(
        placement=plc,
        inputs=[x, index],
        axis=axis,
        vtype=x.vtype,
    )
@_docinject_placement_arg
def sliced(x, begin, end, placement=None):
    """Slice a value between two indices.

    Args:
        x: A sliceable value, e.g. a TensorType or ShapeType.
        begin: Index at which the slice starts, e.g. 1 in ``x[1:5]``.
        end: Exclusive index at which the slice stops, e.g. 5 in ``x[1:5]``.

    Returns:
        The sliced value.
    """
    assert isinstance(x, Expression)
    for bound in (begin, end):
        assert isinstance(bound, int)
    plc = _materialize_placement_arg(placement)
    return SliceExpression(
        placement=plc,
        inputs=[x],
        begin=begin,
        end=end,
        vtype=x.vtype,
    )
# TODO(jvmncs): better docstring
@_docinject_placement_arg
def strided_slice(x, slices, placement=None):
    """Compute a strided slice of a value.

    This is more general than ``sliced``. Instead of a single begin/end pair, it
    accepts a sequence of arbitrary Python ``slice`` objects.

    Args:
        x: A value.
        slices: A list or tuple whose elements are all Python ``slice`` objects.

    Returns:
        The sliced value.
    """
    assert isinstance(x, Expression)
    assert isinstance(slices, (tuple, list))
    placement = _materialize_placement_arg(placement)
    for s in slices:
        if not isinstance(s, slice):
            raise ValueError(
                f"`slices` argument must be a list/tuple of slices, found {type(s)}"
            )
    return StridedSliceExpression(
        placement=placement, inputs=[x], slices=slices, vtype=x.vtype
    )
@_docinject_placement_arg
def transpose(x, placement=None):
    """Transpose a tensor, reversing the order of its dimensions.

    This is equivalent to ``numpy.ndarray.T``.

    Args:
        x: A tensor.

    Returns:
        ``x`` with its dimensions in reverse order.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    return TransposeExpression(
        placement=plc,
        inputs=[x],
        vtype=x.vtype,
    )
def atleast_2d(x, to_column_vector=False, placement=None):
    """Build an ``AtLeast2DExpression`` over ``x``.

    Judging by the name, this presumably mirrors ``numpy.atleast_2d`` (promote
    the input to at least two dimensions). The actual rule is implemented by
    the backend for ``AtLeast2DExpression``; TODO confirm.

    Args:
        x: A tensor expression.
        to_column_vector: Flag forwarded unchanged to ``AtLeast2DExpression``.
            By its name, it presumably selects column rather than row promotion
            for 1-d inputs; TODO confirm.
        placement: Optional placement. A falsy value falls back to the current
            placement context.

    Returns:
        An expression with the same vtype as ``x``.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    return AtLeast2DExpression(
        placement=plc,
        inputs=[x],
        vtype=x.vtype,
        to_column_vector=to_column_vector,
    )
@_docinject_placement_arg
def reshape(x, shape, placement=None):
    """Reshape a tensor.

    Broadcasting is not allowed. The new shape must contain the same total
    number of elements as ``x``.

    Args:
        x: A tensor.
        shape: A list, tuple, or Shape expression giving the new shape. A list or
            tuple is lifted into a shape constant.

    Returns:
        The reshaped tensor.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    if isinstance(shape, (list, tuple)):
        # TODO (Yann) Currently we only have the ability to declare HostShape
        # as constant. We should add the ability to declare RepShape as constant.
        # The constant therefore lives on a host: the first player of a
        # replicated placement, or the placement itself otherwise.
        if isinstance(plc, ReplicatedPlacementExpression):
            shape_plc = plc.players[0]
        else:
            shape_plc = plc
        shape = constant(
            values.ShapeConstant(value=shape),
            vtype=ty.ShapeType(),
            placement=shape_plc,
        )
    assert isinstance(shape, Expression)
    return ReshapeExpression(
        placement=plc,
        inputs=[x, shape],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def abs(x, placement=None):
    """Compute the absolute value.

    When ``x`` is a tensor, the absolute value is taken elementwise.

    Args:
        x: A value.

    Returns:
        The absolute value of ``x``.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)
    return AbsExpression(
        placement=plc,
        inputs=[x],
        vtype=x.vtype,
    )
@_docinject_placement_arg
def mux(selector, x, y, placement=None):
    """Multiplex two tensors according to a boolean selector.

    This gives static control flow over elements drawn from two tensors. For a
    boolean tensor ``s`` and tensors ``x`` and ``y``, the result is equivalent
    to ``s * (x - y) + y``. The following are asserted: ``selector`` is a boolean
    tensor, ``x`` and ``y`` are fixed-point tensors of matching dtype, and the
    placement is replicated.

    Args:
        selector: Boolean tensor holding the control-flow condition.
        x: Tensor whose elements are taken where the selector is 1.
        y: Tensor whose elements are taken where the selector is 0.

    Returns:
        A tensor whose elements come from ``x`` or ``y`` according to
        ``selector``.
    """
    assert isinstance(selector, Expression)
    assert isinstance(selector.vtype, ty.TensorType)
    assert selector.vtype.dtype.is_boolean, selector.vtype.dtype
    # Both data operands get the same checks, x first and then y.
    for operand in (x, y):
        assert isinstance(operand, Expression)
        assert isinstance(operand.vtype, ty.TensorType), operand.vtype
        assert operand.vtype.dtype.is_fixedpoint, operand.vtype.dtype
    plc = _materialize_placement_arg(placement)
    assert isinstance(plc, ReplicatedPlacementExpression)
    out_vtype = _assimilate_arg_vtypes(x.vtype, y.vtype, "mux")
    return MuxExpression(placement=plc, inputs=[selector, x, y], vtype=out_vtype)
@_docinject_placement_arg
def cast(x, dtype, placement=None):
    """Cast a tensor to a new dtype.

    Args:
        x: A tensor with a well-defined (non-None) dtype.
        dtype: The target dtype. Either a Moose ``DType``, or a numpy scalar
            type / ``numpy.dtype`` listed in ``_NUMPY_DTYPES_MAP``.

    Returns:
        ``x`` converted to ``dtype``. If ``x`` already has that dtype, ``x``
        itself is returned.
    """
    assert isinstance(x, Expression)
    plc = _materialize_placement_arg(placement)

    source_vtype = x.vtype
    if not isinstance(source_vtype, ty.TensorType):
        raise ValueError(
            f"Argument to `cast` operation must be tensor, found {source_vtype}."
        )

    # Both the source and the target must have a concrete dtype.
    if source_vtype.dtype is None:
        raise ValueError(
            "Argument to `cast` function must have well-defined dtype; "
            "found value with dtype=None."
        )
    if dtype is None:
        raise ValueError(
            "Invalid `dtype` argument to `cast` function: cannot cast to dtype=None."
        )

    # Translate numpy dtype specifiers into Moose DTypes.
    if isinstance(dtype, dtypes.DType):
        target_dtype = dtype
    elif dtype in _NUMPY_DTYPES_MAP:
        target_dtype = _NUMPY_DTYPES_MAP[dtype]
    else:
        raise ValueError(
            "Unsupported dtype arg in `cast` function: expected argument "
            f"of type DType, found type {type(dtype)}."
        )

    if source_vtype.dtype == target_dtype:
        # Casting to the dtype the tensor already has is a no-op.
        return x
    return CastExpression(
        placement=plc,
        inputs=[x],
        vtype=ty.TensorType(target_dtype),
    )
@_docinject_placement_arg
def load(key, query="", dtype=None, vtype=None, placement=None):
    """Load a value from placement storage.

    The key-value store behind this op is runtime-specific. In general, assume
    each HostPlacement may use a different storage implementation in its Moose
    worker/executor.

    Args:
        key: A string or Moose String naming the value to load. A plain ``str``
            is lifted into a String constant on the placement.
        query: An optional query string/String passed to executor storage. Most
            common storage implementations ignore it.
        dtype: If the value should be loaded as a tensor, the DType to coerce the
            tensor to. If None, inferred from the value's numpy dtype.
        vtype: The Moose type to coerce the loaded value into. If None, it is
            traced as :class:`~pymoose.computation.types.UnknownType`, and the
            compiler attempts to fill it in during its initial Typing pass.

    Returns:
        The loaded value, as provided by the worker backing the HostPlacement.
    """
    plc = _materialize_placement_arg(placement)
    vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)

    # Normalize `key`. A str becomes a String constant. An Argument must be
    # String-typed or untyped. Anything else must already be an Expression.
    if isinstance(key, str):
        key_expr = constant(key, placement=plc, vtype=ty.StringType())
    elif isinstance(key, Argument) and key.vtype not in [ty.StringType(), None]:
        raise ValueError(
            f"Function 'edsl.load' encountered `key` argument of vtype {key.vtype}; "
            "expected `StringType`."
        )
    elif not isinstance(key, Expression):
        raise ValueError(
            f"Function 'edsl.load' encountered `key` argument of type {type(key)}; "
            "expected one of: string, ConstantExpression, or Argument."
        )
    else:
        key_expr = key

    # `query` follows the same rules as `key`.
    if isinstance(query, str):
        query_expr = constant(query, placement=plc, vtype=ty.StringType())
    elif isinstance(query, Argument) and query.vtype not in [ty.StringType(), None]:
        raise ValueError(
            f"Function 'edsl.load' encountered `query` argument of "
            f"vtype {query.vtype}; expected 'StringType'."
        )
    elif not isinstance(query, Expression):
        raise ValueError(
            f"Function 'edsl.load' encountered `query` argument of type {type(query)}; "
            "expected one of: string, ConstantExpression, or Argument."
        )
    else:
        query_expr = query

    return LoadExpression(placement=plc, inputs=[key_expr, query_expr], vtype=vtype)
@_docinject_placement_arg
def save(key, value, placement=None):
    """Save a key-value pair to placement storage.

    The key-value store behind this op is runtime-specific. In general, assume
    each HostPlacement may use a different storage implementation in its Moose
    worker/executor.

    Args:
        key: A string or Moose String. A plain ``str`` is lifted into a String
            constant on the placement.
        value: A Moose Value.

    Returns:
        A Moose Value of type Unit.
    """
    assert isinstance(value, Expression)
    plc = _materialize_placement_arg(placement)

    # Normalize `key`. A str becomes a String constant. An Argument must be
    # String-typed or untyped. Anything else must already be an Expression.
    if isinstance(key, str):
        key_expr = constant(key, placement=plc, vtype=ty.StringType())
    elif isinstance(key, Argument) and key.vtype not in [ty.StringType(), None]:
        raise ValueError(
            f"Function 'edsl.save' encountered `key` argument of type {key.vtype}; "
            "expected 'StringType'."
        )
    elif not isinstance(key, Expression):
        raise ValueError(
            f"Function 'edsl.save' encountered `key` argument of type {type(key)}; "
            "expected one of: string, ConstantExpression, or ArgumentExpression."
        )
    else:
        key_expr = key

    return SaveExpression(placement=plc, inputs=[key_expr, value], vtype=None)
@_docinject_placement_arg
def output(tag, value, placement=None):
    """Tag an output of a computation.

    This works like ``identity``, but it also attaches a tag to the value. Use
    it to pin an output to a particular placement without writing it to
    placement storage. It also helps keep computation outputs ordered in
    PyMoose, because Moose compilation generally does not preserve output
    order.

    Args:
        tag: A string associated with the output, e.g. to rebuild the original
            order of outputs.
        value: The Moose value to output.

    Returns:
        The tagged value.
    """
    assert isinstance(value, Expression)
    assert isinstance(tag, str)
    plc = _materialize_placement_arg(placement)
    return OutputExpression(
        placement=plc,
        inputs=[value],
        vtype=value.vtype,
        tag=tag,
    )
def computation(func=None, role_map=None):
    """Mark a Python function as a Moose computation.

    This can be used bare (``@computation``) or with arguments
    (``@computation(role_map=...)``). In the second form, ``func`` is None and
    a partial application is returned, which then receives the function.

    Args:
        func: A callable, or None when used as a decorator factory.
        role_map: A map of abstract placements to identities in the current
            runtime context.

    Returns:
        An ``AbstractComputation`` wrapping ``func``, or, if ``func`` is None, a
        ``functools.partial`` that builds one.
    """
    if func is not None:
        return AbstractComputation(func, role_map)
    return ft.partial(computation, role_map=role_map)
class AbstractComputation:
    """A Python function wrapped as a Moose computation.

    Instances are created by :func:`computation`. Calling one binds the given
    arguments to the wrapped function's parameter names and passes them to the
    current runtime's ``evaluate_computation``.
    """

    def __init__(self, func, role_map):
        if not callable(func):
            raise TypeError(
                f"Argument `func` should be a callable, but found {type(func)}."
            )
        self.func = func
        if role_map is not None and not isinstance(role_map, dict):
            raise TypeError(
                "Argument `role_map` should be map of placement names to placement "
                f"names, found {type(role_map)}."
            )
        self.role_map = role_map

    def __call__(self, *args, **kwargs):
        """Bind ``args``/``kwargs`` to ``func``'s parameters and evaluate.

        Every parameter of ``func`` must be supplied exactly once, either
        positionally or by keyword. Default parameter values are not used, and
        unknown keywords are rejected.

        Raises:
            ValueError: On too many positional arguments, or on a duplicated,
                missing, or unknown argument.
            RuntimeError: If the current runtime (``get_current_runtime()``) is
                unset or falsy.
        """
        param_names = list(inspect.signature(self.func).parameters)

        if len(args) > len(param_names):
            raise ValueError(f"Too many arguments for `{self.func.__name__}`")
        # Positional values fill parameters in declaration order.
        arguments = dict(zip(param_names, args))

        repeated = next((name for name in kwargs if name in arguments), None)
        if repeated is not None:
            raise ValueError(
                f"Argument `{repeated}` given more than once to "
                f"`{self.func.__name__}`"
            )
        arguments.update(kwargs)

        missing = next((name for name in param_names if name not in arguments), None)
        if missing is not None:
            raise ValueError(
                f"Missing argument `{missing}` in call to `{self.func.__name__}`"
            )

        # NOTE we could potentially leave out this check
        unused = next((name for name in arguments if name not in param_names), None)
        if unused is not None:
            raise ValueError(
                f"Argument `{unused}` is not used by `{self.func.__name__}`"
            )

        runtime = get_current_runtime()
        if not runtime:
            raise RuntimeError("No default runtime found")
        return runtime.evaluate_computation(self, arguments)

    def with_role_map(self, role_map):
        """Return a copy of this computation that uses ``role_map``."""
        return type(self)(self.func, role_map)
def _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name):
    """Require two tensor types to share a dtype, and return the left one.

    Raises:
        ValueError: If ``lhs_vtype.dtype`` and ``rhs_vtype.dtype`` differ.
    """
    lhs_dtype, rhs_dtype = lhs_vtype.dtype, rhs_vtype.dtype
    if lhs_dtype != rhs_dtype:
        raise ValueError(
            f"Function `{fn_name}` expected arguments of similar dtype: "
            f"found mismatched dtypes `{lhs_dtype}` and `{rhs_dtype}`."
        )
    return lhs_vtype
def _assimilate_arg_vtypes(lhs_vtype, rhs_vtype, fn_name):
    """Check that two operand types agree, and return the left one.

    Two tensor types only need matching dtypes, which is delegated to
    ``_assimilate_arg_dtypes``. Any other pair of types must compare equal.

    Raises:
        ValueError: If the types (or, for tensors, the dtypes) disagree.
    """
    both_tensors = isinstance(lhs_vtype, ty.TensorType) and isinstance(
        rhs_vtype, ty.TensorType
    )
    if both_tensors:
        return _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name)
    if lhs_vtype != rhs_vtype:
        raise ValueError(
            f"Function `{fn_name}` expected arguments of similar type: "
            f"found mismatched types `{lhs_vtype}` and `{rhs_vtype}`."
        )
    return lhs_vtype
def _check_tensor_type_arg_consistency(dtype, vtype):
    """Raise if ``vtype`` is a tensor type whose dtype is not ``dtype``.

    Non-tensor ``vtype`` values are accepted without checks.

    Raises:
        ValueError: On a dtype mismatch for a tensor ``vtype``.
    """
    if not isinstance(vtype, ty.TensorType):
        return
    if vtype.dtype != dtype:
        raise ValueError(
            f"Inconsistent type information for tensor: dtype {dtype} is "
            f"inconsistent with tensor type {vtype}."
        )
def _materialize_placement_arg(plc):
    """Resolve an optional placement argument to a concrete placement.

    A falsy ``plc`` (e.g. None) falls back to ``get_current_placement()``.

    Raises:
        TypeError: If the resolved value is not a ``PlacementExpression``.
    """
    resolved = plc if plc else get_current_placement()
    if isinstance(resolved, PlacementExpression):
        return resolved
    raise TypeError(f"Expected value of type Placement, found {type(resolved)}.")
def _maybe_lift_dtype_to_tensor_vtype(dtype, vtype):
    """Combine optional ``dtype``/``vtype`` arguments into a single vtype.

    - Neither given: returns None.
    - Only ``dtype``: returns ``ty.TensorType(dtype)``.
    - ``vtype`` given: returns ``vtype``. If ``dtype`` is also given, it is
      first checked for consistency with ``vtype``.
    """
    if vtype is None:
        return None if dtype is None else ty.TensorType(dtype)
    if dtype is not None:
        _check_tensor_type_arg_consistency(dtype, vtype)
    return vtype
def _interpret_numeric_value(value, vtype, fallback_vtype):
    """Wrap a Python scalar in the Moose constant that matches a type.

    Args:
        value: A Python ``int`` or ``float``.
        vtype: The requested Moose type, or None to use ``fallback_vtype``.
        fallback_vtype: The type to use when ``vtype`` is None.

    Returns:
        A ``(constant, vtype)`` pair. ``constant`` is a ``TensorConstant``,
        ``FloatConstant`` or ``IntConstant``, and ``vtype`` is the resolved type.

    Raises:
        TypeError: If the resolved type is not numeric, i.e. a tensor type whose
            dtype is neither float nor integer, or a non-tensor type other than
            FloatType/IntType.
    """
    assert isinstance(value, (int, float))
    if vtype is None:
        vtype = fallback_vtype
    if isinstance(vtype, ty.TensorType):
        dtype = vtype.dtype
        if not dtype.is_float and not dtype.is_integer:
            raise TypeError(f"Cannot interpret scalar constant as dtype {dtype}.")
        value = values.TensorConstant(np.array(value, dtype=dtype.numpy_dtype))
    elif isinstance(vtype, ty.FloatType):
        value = values.FloatConstant(value)
    elif isinstance(vtype, ty.IntType):
        value = values.IntConstant(value)
    else:
        # f-string so that the offending type actually appears in the message.
        raise TypeError(
            f"Cannot interpret numeric constant as non-numeric type {vtype}."
        )
    return value, vtype