"""Lookup helpers for model protos (module ``pymoose.predictors.predictor_utils``)."""
from pymoose.computation import dtypes
# Module-wide default dtypes: one floating-point, one fixed-point.
# NOTE(review): the arguments of dtypes.fixed(24, 40) presumably are the
# integral / fractional precision — confirm against pymoose.computation.dtypes.
DEFAULT_FLOAT_DTYPE, DEFAULT_FIXED_DTYPE = (
    dtypes.float64,
    dtypes.fixed(24, 40),
)
def find_attribute_in_node(node, attribute_name, enforce=True):
    """Look up an attribute of ``node`` by its ``name``.

    Args:
        node: Object exposing ``name`` and an iterable ``attribute`` whose
            items carry a ``name`` field.
        attribute_name: Name of the attribute to find.
        enforce: If True, a missing attribute is an error.

    Returns:
        The matching attribute. If several share the name, the last one
        wins. Returns None when nothing matches and ``enforce`` is False.

    Raises:
        ValueError: ``enforce`` is True and no attribute has that name.
    """
    matches = [attr for attr in node.attribute if attr.name == attribute_name]
    # The last match wins, mirroring a full scan that keeps overwriting.
    node_attr = matches[-1] if matches else None
    if node_attr is None and enforce:
        raise ValueError(
            f"Node {node.name} does not contain attribute {attribute_name}."
        )
    return node_attr
def find_node_in_model_proto(model_proto, operator_name, enforce=True):
    """Return the graph node named ``operator_name``.

    Args:
        model_proto: Object exposing ``graph.node``, an iterable of nodes
            with a ``name`` field.
        operator_name: Node name to search for.
        enforce: If True, a missing node is an error.

    Returns:
        The matching node (the last one if names repeat), or None when
        nothing matches and ``enforce`` is False.

    Raises:
        ValueError: ``enforce`` is True and no node has that name.
    """
    candidates = [op for op in model_proto.graph.node if op.name == operator_name]
    if candidates:
        return candidates[-1]
    if enforce:
        raise ValueError(f"Model proto does not contain operator {operator_name}.")
    return None
def find_initializer_in_model_proto(model_proto, operator_name, enforce=True):
    """Return the graph initializer named ``operator_name`` and its ``dims``.

    Args:
        model_proto: Object exposing ``graph.initializer``, an iterable of
            initializers with ``name`` and ``dims`` fields.
        operator_name: Initializer name to search for.
        enforce: If True, a missing initializer is an error.

    Returns:
        ``(initializer, initializer.dims)`` for the matching initializer
        (the last one if names repeat). When nothing matches and
        ``enforce`` is False, returns ``(None, None)``.

    Raises:
        ValueError: ``enforce`` is True and no initializer has that name.
    """
    matches = [
        init for init in model_proto.graph.initializer if init.name == operator_name
    ]
    if not matches:
        if enforce:
            raise ValueError(
                f"Model proto does not contain operator {operator_name}."
            )
        # Previously this path read `.dims` on None and crashed. Keep the
        # 2-tuple shape so callers that unpack the result still work.
        return None, None
    initializer = matches[-1]
    return initializer, initializer.dims
def find_activation_in_model_proto(model_proto, operator_name, enforce=True):
    """Return the name of the node whose first output is ``operator_name``.

    Args:
        model_proto: Object exposing ``graph.node``, an iterable of nodes
            with ``name`` and an indexable ``output`` sequence.
        operator_name: Output name to match against each node's first output.
        enforce: If True, finding no producing node is an error.

    Returns:
        The ``name`` of the matching node (the last one if several match),
        or None when nothing matches and ``enforce`` is False.

    Raises:
        ValueError: ``enforce`` is True and no node's first output matches.
    """
    activation = None
    for operator in model_proto.graph.node:
        # Skip nodes with no outputs instead of failing with IndexError.
        if operator.output and operator.output[0] == operator_name:
            activation = operator.name
    if enforce and activation is None:
        raise ValueError(f"Model proto does not contain operator {operator_name}.")
    return activation
def find_parameters_in_model_proto(model_proto, operator_names, enforce=True):
    """Return the graph initializers whose name contains any of ``operator_names``.

    Matching is by substring (``name in initializer.name``). Results follow
    initializer order. Within one initializer, matches follow the order of
    ``operator_names``. An initializer that contains several of the names is
    included once per matching name.

    Args:
        model_proto: Object exposing ``graph.initializer``, an iterable of
            initializers with a ``name`` field.
        operator_names: Re-iterable collection of name fragments (strings).
        enforce: If True, finding no matching initializer is an error.

    Returns:
        A list of matching initializers. It may be empty when ``enforce``
        is False.

    Raises:
        ValueError: ``enforce`` is True and nothing matches.
    """
    parameters = [
        initializer
        for initializer in model_proto.graph.initializer
        for operator_name in operator_names
        if operator_name in initializer.name
    ]
    if enforce and not parameters:
        # Name every requested fragment rather than a leaked loop variable.
        # That variable is unbound when the graph has no initializers.
        requested = ", ".join(operator_names)
        raise ValueError(f"Model proto does not contain operator {requested}.")
    return parameters
def find_op_types_in_model_proto(model_proto, enforce=True):
    """Collect the ``op_type`` of every node in the graph, in graph order.

    Args:
        model_proto: Object exposing ``graph.node``, an iterable of nodes
            with an ``op_type`` field.
        enforce: If True, a graph without nodes is an error.

    Returns:
        A list with one ``op_type`` per node. It may be empty when
        ``enforce`` is False.

    Raises:
        ValueError: ``enforce`` is True and the graph has no nodes.
    """
    op_types = [graph_node.op_type for graph_node in model_proto.graph.node]
    if not op_types and enforce:
        raise ValueError("Model proto nodes do not contain op_type.")
    return op_types
def find_output_in_model_proto(model_proto, enforce=True, output_index=0):
    """Return the shape ``dim`` entries of one graph output.

    ``model_proto.graph.output`` is a repeated field with one entry per
    graph output. An element must therefore be selected before reading
    ``.type.tensor_type.shape.dim``. Reading ``.type`` on the container
    itself fails.

    Args:
        model_proto: Object exposing ``graph.output``, an indexable sequence
            of outputs with ``type.tensor_type.shape.dim``.
        enforce: If True, a missing output or an empty ``dim`` list is an
            error.
        output_index: Which graph output to inspect. Defaults to the first.

    Returns:
        The output's ``dim`` container. Returns None when the requested
        output does not exist and ``enforce`` is False.

    Raises:
        ValueError: ``enforce`` is True and the output is missing or has no
            dimensions.
    """
    outputs = model_proto.graph.output
    try:
        output = outputs[output_index]
    except IndexError:
        output = None
    output_dim = None if output is None else output.type.tensor_type.shape.dim
    # `dim` is always a container for a present output, so treat an empty
    # one as "no dimension information", not only None.
    if enforce and not output_dim:
        raise ValueError("Model proto does not contain output dimension.")
    return output_dim