Source code for pymoose.predictors.predictor

import abc

import pymoose as pm
from pymoose.predictors import predictor_utils as utils


class Predictor(metaclass=abc.ABCMeta):
    """Base class for Moose Predictor interface.

    On construction, the standard set of placements is created and stored:
    three host placements (``alice``, ``bob``, ``carole``) plus a mirrored
    and a replicated placement whose players are those three hosts.
    """

    def __init__(self):
        hosts, mirrored, replicated = self._standard_replicated_placements()
        self.alice, self.bob, self.carole = hosts
        self.mirrored = mirrored
        self.replicated = replicated

    @classmethod
    def fixedpoint_constant(cls, x, plc=None, dtype=utils.DEFAULT_FIXED_DTYPE):
        """Convenience method for embedding a constant with fixedpoint dtype.

        ``x`` is first embedded as a ``pm.float64`` constant on ``plc`` and is
        then cast to ``dtype`` on that same placement.

        Args:
            x: Value to embed as a constant.
            plc: Placement for both the constant and the cast (may be None).
            dtype: Target fixed-point dtype; defaults to
                ``utils.DEFAULT_FIXED_DTYPE``.

        Returns:
            The result of the ``pm.cast`` call.
        """
        float_constant = pm.constant(x, dtype=pm.float64, placement=plc)
        return pm.cast(float_constant, dtype=dtype, placement=plc)

    @classmethod
    def handle_output(
        cls, prediction, prediction_handler, output_dtype=utils.DEFAULT_FLOAT_DTYPE
    ):
        """Pins a value to an output placement, casting into a specified dtype.

        Args:
            prediction: Value to cast.
            prediction_handler: Placement entered as a context manager around
                the cast.
            output_dtype: Target dtype; defaults to
                ``utils.DEFAULT_FLOAT_DTYPE``.

        Returns:
            The result of casting ``prediction`` to ``output_dtype``.
        """
        with prediction_handler:
            return pm.cast(prediction, dtype=output_dtype)

    @property
    def host_placements(self):
        """The ``(alice, bob, carole)`` host placements as a tuple."""
        return (self.alice, self.bob, self.carole)

    def _standard_replicated_placements(self):
        """Standard set of abstract placements needed for replicated computations.

        Returns:
            A tuple ``((alice, bob, carole), mirrored, replicated)``. The
            mirrored and replicated placements each receive a fresh list of
            the three host placements as their players.
        """
        # Hosts are created in alice -> bob -> carole order, as before.
        hosts = tuple(pm.host_placement(name) for name in ("alice", "bob", "carole"))
        # Replicated is built before mirrored to keep the original creation order.
        replicated = pm.replicated_placement(name="replicated", players=list(hosts))
        mirrored = pm.mirrored_placement(name="mirrored", players=list(hosts))
        return hosts, mirrored, replicated
def AesWrapper(inner_model_cls):
    """Extend a predictor class so its computation takes AES-encrypted inputs.

    Calling an instance of the returned ``AesPredictor`` class builds a
    ``pm.computation`` with these steps:

    * it takes the AES-encrypted data on ``alice`` and the AES key on the
      replicated placement;
    * it decrypts the data on the replicated placement;
    * it runs the wrapped class's ``predictor_fn`` under the replicated
      placement;
    * it hands the prediction to ``handle_output`` with ``bob`` as the output
      placement.

    Args:
        inner_model_cls: Predictor class to extend. The generated subclass
            uses ``self.alice``, ``self.bob``, ``self.replicated``,
            ``self.predictor_fn`` and ``self.handle_output``, so the wrapped
            class is expected to provide them. This is not checked here.

    Returns:
        The ``AesPredictor`` subclass of ``inner_model_cls``.
    """

    class AesPredictor(inner_model_cls):
        """Predictor extension that adds methods dealing with AesTensor inputs."""

        def __call__(self, fixedpoint_dtype=utils.DEFAULT_FIXED_DTYPE):
            """Build and return the AES-input computation for ``fixedpoint_dtype``."""
            build = self.aes_predictor_factory
            return build(fixedpoint_dtype)

        @classmethod
        def handle_aes_input(cls, aes_key, aes_data, decryptor):
            """Convenience method for decrypting AES inputs on a given placement.

            Args:
                aes_key: Value whose ``vtype`` must be a ``pm.AesKeyType``.
                aes_data: Value whose ``vtype`` must be a ``pm.AesTensorType``
                    with a fixed-point dtype.
                decryptor: Placement entered as a context manager around
                    ``pm.decrypt``.

            Returns:
                The result of ``pm.decrypt(aes_key, aes_data)``.
            """
            # NOTE(review): these are `assert` statements and are skipped
            # entirely when Python runs with -O.
            assert isinstance(aes_data.vtype, pm.AesTensorType)
            assert aes_data.vtype.dtype.is_fixedpoint
            assert isinstance(aes_key.vtype, pm.AesKeyType)
            with decryptor:
                return pm.decrypt(aes_key, aes_data)

        def aes_predictor_factory(self, fixedpoint_dtype=utils.DEFAULT_FIXED_DTYPE):
            """Wraps a class's predictor_fn with replicated AES decryption of inputs.

            Args:
                fixedpoint_dtype: Fixed-point dtype of the encrypted tensor
                    argument. It is also passed through to ``predictor_fn``.

            Returns:
                The ``pm.computation``-decorated ``predictor`` function.
            """

            # Parameter names and annotations define the computation's
            # arguments, so they are kept exactly as they were.
            @pm.computation
            def predictor(
                aes_data: pm.Argument(
                    self.alice, vtype=pm.AesTensorType(dtype=fixedpoint_dtype)
                ),
                aes_key: pm.Argument(self.replicated, vtype=pm.AesKeyType()),
            ):
                decrypted = self.handle_aes_input(
                    aes_key, aes_data, decryptor=self.replicated
                )
                with self.replicated:
                    prediction = self.predictor_fn(decrypted, fixedpoint_dtype)
                return self.handle_output(prediction, prediction_handler=self.bob)

            return predictor

    return AesPredictor