import abc
import pymoose as pm
from pymoose.predictors import predictor
from pymoose.predictors import predictor_utils as utils
[docs]class DecisionTreeRegressor(predictor.Predictor):
def __init__(self, weights, children, split_conditions, split_indices):
super().__init__()
self.weights = weights
self.left, self.right = children
self.split_conditions = split_conditions
self.split_indices = split_indices
[docs] @classmethod
def from_json(cls, tree_json):
weights = dict(enumerate(tree_json["base_weights"]))
left = _map_json_to_onnx_leaves(tree_json["left_children"])
right = _map_json_to_onnx_leaves(tree_json["right_children"])
split_conditions = tree_json["split_conditions"]
split_indices = tree_json["split_indices"]
return cls(weights, (left, right), split_conditions, split_indices)
[docs] def aes_predictor_factory(self):
raise NotImplementedError(
f"{self.__class__.__name__} is not meant to be used directly as an "
"AesPredictor model. Consider expressing your decision tree as a tree "
"ensemble with another AesPredictor implementation."
)
def __call__(self, x, n_features, rescale_factor, fixedpoint_dtype):
leaf_weights = {ix: rescale_factor * w for ix, w in self.weights.items()}
features_vec = [pm.index_axis(x, axis=1, index=i) for i in range(n_features)]
return self._traverse_tree(0, leaf_weights, features_vec, fixedpoint_dtype)
def _traverse_tree(self, node, leaf_weights, x_features, fixedpoint_dtype):
left_child = self.left[node]
right_child = self.right[node]
if left_child != 0 and right_child != 0:
# we're at an inner node; this is the recursive case
selector = pm.less(
x_features[self.split_indices[node]],
self.fixedpoint_constant(
self.split_conditions[node], self.mirrored, dtype=fixedpoint_dtype
),
)
return pm.mux(
selector,
self._traverse_tree(
left_child, leaf_weights, x_features, fixedpoint_dtype
),
self._traverse_tree(
right_child, leaf_weights, x_features, fixedpoint_dtype
),
)
else:
assert left_child == 0
assert right_child == 0
return self.fixedpoint_constant(leaf_weights[node], self.carole)
[docs]class TreeEnsemble(predictor.Predictor, metaclass=abc.ABCMeta):
def __init__(self, trees, n_features, base_score, learning_rate):
super().__init__()
self.n_features = n_features
self.trees = trees
self.base_score = base_score
self.learning_rate = learning_rate
[docs] @classmethod
@abc.abstractmethod
def from_onnx(cls, model_proto):
pass
[docs] @abc.abstractmethod
def post_transform(self, tree_scores, fixedpoint_dtype):
pass
[docs] def predictor_fn(self, x, fixedpoint_dtype):
forest_scores = [
tree(
x,
self.n_features,
rescale_factor=self.learning_rate,
fixedpoint_dtype=fixedpoint_dtype,
)
for tree in self.trees
]
# if any of the trees are degenerate, they will return a non-replicated value.
# we want post_transform to expect a collection of replicated values, since its
# variadic ops will not necessarily know to move their results from source
# placements to replicated placement.
# it's a bit ugly, but it works for now.
return list(map(pm.identity, forest_scores))
def __call__(self, x, fixedpoint_dtype=utils.DEFAULT_FIXED_DTYPE):
tree_scores = self.predictor_fn(x, fixedpoint_dtype=fixedpoint_dtype)
return self.post_transform(tree_scores, fixedpoint_dtype=fixedpoint_dtype)
[docs]class TreeEnsembleClassifier(TreeEnsemble):
"""Tree ensemble classification Predictor for GBTs and Random Forests.
This class can be used for binary, multiclass, or multilabel classification. Support
for multiclass classification uses the one-vs-rest method.
Args:
trees: Nested collection of :class:`~DecisionTreeRegressor`s.
n_features: Number of features expected for input data.
n_classes: Number of output classes.
base_score: The base score for the underlying tree ensemble model, similar to a
bias/intercept term.
learning_rate: Learning rate parameter used to re-scale leaf weights in the
model trees.
transform_output: Boolean determining whether a softmax should be applied to
derive probabilities from the tree ensemble output.
tree_class_map: Dictionary mapping ``trees`` indices to class indices. Keeps
track of which trees in ``trees`` correspond to each class in the
one-vs-rest formulation of multiclass classification.
"""
def __init__(
self,
trees,
n_features,
n_classes,
base_score,
learning_rate,
transform_output,
tree_class_map,
):
super().__init__(trees, n_features, base_score, learning_rate)
self.n_classes = n_classes
self.tree_class_map = tree_class_map
self.transform_output = transform_output
[docs] @classmethod
def from_onnx(cls, model_proto):
"""Construct a TreeEnsembleClassifier from a parsed ONNX model.
Args:
model_proto: An ONNX ModelProto containing a TreeEnsembleClassifier node.
Returns:
A TreeEnsembleClassifier built from the parameters and configuration of the
given ONNX model.
Raises:
ValueError if ONNX graph is missing expected nodes.
"""
(
forest_node,
(nodes_treeids, left, right, split_conditions, split_indices),
n_trees,
n_features,
base_score,
learning_rate,
) = _onnx_base(model_proto, "TreeEnsembleClassifier")
class_ids_attr = utils.find_attribute_in_node(forest_node, "class_ids")
assert class_ids_attr.type == 7
class_ids = class_ids_attr.ints
class_nodeids_attr = utils.find_attribute_in_node(forest_node, "class_nodeids")
assert class_nodeids_attr.type == 7
class_nodeids = class_nodeids_attr.ints
class_treeids_attr = utils.find_attribute_in_node(forest_node, "class_treeids")
assert class_treeids_attr.type == 7
class_treeids = class_treeids_attr.ints
class_weights_attr = utils.find_attribute_in_node(forest_node, "class_weights")
assert class_weights_attr.type == 6
class_weights = class_weights_attr.floats
classlabels_ints = utils.find_attribute_in_node(
forest_node, "classlabels_int64s", enforce=False
)
classlabels_strings = utils.find_attribute_in_node(
forest_node, "classlabels_strings", enforce=False
)
assert classlabels_ints is not None or classlabels_strings is not None
if classlabels_ints is not None:
classlabels = classlabels_ints.ints
elif classlabels_strings is not None:
classlabels = classlabels_strings.strings
n_classes = len(classlabels)
post_transform_attr = utils.find_attribute_in_node(
forest_node, "post_transform"
)
post_transform = post_transform_attr.s.decode()
if post_transform == "NONE" and n_classes > 2:
# in this case, sklearn's ONNX file stores nodes differently;
# each leaf & inner node has array of length n_classes instead of
# having n_trees * n_classes separate trees, whereas other ONNX
# files have separate trees per class.
# in TreeEnsembleClassifier, we currently always represent with a separate
# forest per class, so here we need to duplicate some trees for that
# representation.
final_class_treeids = [
class_id + tree_id * n_classes
for (tree_id, class_id) in zip(class_treeids, class_ids)
]
# update n_trees inferred by onnx helper fn above
n_trees = len(set(final_class_treeids))
# rely on nodes_treeids being sorted to preserve sublist order.
# the order matters to map back into the format expected when there are
# separate forests for each class
assert nodes_treeids == sorted(nodes_treeids)
sublists = [
list(filter(lambda x: x == i, nodes_treeids))
for i in sorted(set(nodes_treeids))
]
repeated_sublists = [
[n_classes * i + j for _ in x]
for j in range(n_classes)
for i, x in enumerate(sublists)
]
final_nodes_treeids = [x for y in repeated_sublists for x in y]
else:
final_class_treeids = class_treeids
final_nodes_treeids = nodes_treeids
tree_args = [
{
"weights": {},
"children": [[], []],
"split_indices": [],
"split_conditions": [],
}
for _ in range(n_trees)
]
for i, tree_id in enumerate(final_nodes_treeids):
# i % len(_) duplicates nodes from the same ONNX trees in cases when
# final_nodes_treeids is longer than the lists of nodes coming from ONNX
# this is only the case when there are not n_trees * n_classes distinct
# trees in the ONNX file
tree_args[tree_id]["children"][0].append(left[i % len(left)])
tree_args[tree_id]["children"][1].append(right[i % len(right)])
tree_args[tree_id]["split_indices"].append(
split_indices[i % len(split_indices)]
)
tree_args[tree_id]["split_conditions"].append(
split_conditions[i % len(split_conditions)]
)
for i, class_weight in enumerate(class_weights):
tree_args[final_class_treeids[i]]["weights"][
class_nodeids[i]
] = class_weight
trees = [DecisionTreeRegressor(**kwargs) for kwargs in tree_args]
tree_class_map = {
tree_id: class_id
for tree_id, class_id in zip(final_class_treeids, class_ids)
}
transform_output = post_transform != "NONE"
return cls(
trees,
n_features,
n_classes,
base_score,
learning_rate,
transform_output,
tree_class_map,
)
[docs] def post_transform(self, tree_scores, fixedpoint_dtype):
if self.n_classes == 2:
return self._maybe_sigmoid(tree_scores, fixedpoint_dtype)
else:
logit = self._ovr_logit(
tree_scores, axis=1, fixedpoint_dtype=fixedpoint_dtype
)
if self.transform_output:
return pm.softmax(logit, axis=1, upmost_index=self.n_classes)
return logit
def _maybe_sigmoid(self, tree_scores, fixedpoint_dtype):
base_score = self.fixedpoint_constant(
self.base_score, self.carole, dtype=fixedpoint_dtype
)
logit = pm.add(pm.add_n(tree_scores), base_score)
pos_prob = pm.sigmoid(logit) if self.transform_output else logit
pos_prob = pm.expand_dims(pos_prob, axis=1)
one = self.fixedpoint_constant(1, plc=self.mirrored, dtype=fixedpoint_dtype)
neg_prob = pm.sub(one, pos_prob)
return pm.concatenate([neg_prob, pos_prob], axis=1)
def _ovr_logit(self, tree_scores, axis, fixedpoint_dtype):
ovr_results = [[] for _ in range(self.n_classes)]
for tree_ix, model_ix in self.tree_class_map.items():
ovr_results[model_ix].append(tree_scores[tree_ix])
base_score = self.fixedpoint_constant(
self.base_score, self.carole, dtype=fixedpoint_dtype
)
ovr_logits = [pm.add(pm.add_n(ovr), base_score) for ovr in ovr_results]
reformed_logits = pm.concatenate(
[pm.expand_dims(ovr, axis=axis) for ovr in ovr_logits], axis=axis
)
return reformed_logits
[docs]class TreeEnsembleRegressor(TreeEnsemble):
"""Tree ensemble regression Predictor, accommodating both GBTs and Random Forests.
Args:
trees: Nested collection of :class:`~DecisionTreeRegressor`s.
n_features: Number of features expected for input data.
base_score: The base score for the underlying tree ensemble model, similar to a
bias/intercept term.
learning_rate: Learning rate parameter used to re-scale leaf weights in the
model trees.
"""
[docs] @classmethod
def from_onnx(cls, model_proto):
"""Construct a TreeEnsembleRegressor from a parsed ONNX model.
Args:
model_proto: An ONNX ModelProto containing a TreeEnsembleRegressor node.
Returns:
A TreeEnsembleRegressor built from the parameters and configuration of the
given ONNX model.
Raises:
ValueError if ONNX graph is missing expected nodes.
"""
(
forest_node,
(nodes_treeids, left, right, split_conditions, split_indices),
n_trees,
n_features,
base_score,
learning_rate,
) = _onnx_base(model_proto, "TreeEnsembleRegressor")
target_nodeids_attr = utils.find_attribute_in_node(
forest_node, "target_nodeids"
)
assert target_nodeids_attr.type == 7
target_nodeids = target_nodeids_attr.ints
target_treeids_attr = utils.find_attribute_in_node(
forest_node, "target_treeids"
)
assert target_treeids_attr.type == 7
target_treeids = target_treeids_attr.ints
target_weights_attr = utils.find_attribute_in_node(
forest_node, "target_weights"
)
assert target_weights_attr.type == 6 # FLOATS
target_weights = target_weights_attr.floats
tree_args = [
{
"weights": {},
"children": [[], []],
"split_indices": [],
"split_conditions": [],
}
for _ in range(n_trees)
]
for i, tree_id in enumerate(nodes_treeids):
tree_args[tree_id]["children"][0].append(left[i])
tree_args[tree_id]["children"][1].append(right[i])
tree_args[tree_id]["split_indices"].append(split_indices[i])
tree_args[tree_id]["split_conditions"].append(split_conditions[i])
for i, tree_id in enumerate(target_treeids):
tree_args[tree_id]["weights"][target_nodeids[i]] = target_weights[i]
trees = [DecisionTreeRegressor(**kwargs) for kwargs in tree_args]
return cls(trees, n_features, base_score, learning_rate)
[docs] def post_transform(self, tree_scores, fixedpoint_dtype):
base_score = self.fixedpoint_constant(
self.base_score, self.carole, dtype=fixedpoint_dtype
)
penultimate_score = pm.add_n(tree_scores)
return pm.add(base_score, penultimate_score)
def _map_json_to_onnx_leaves(json_leaves):
return [0 if child == -1 else child for child in json_leaves]
def _onnx_base(model_proto, forest_node_name):
forest_node = utils.find_node_in_model_proto(
model_proto, forest_node_name, enforce=False
)
if forest_node is None:
raise ValueError(
"Incompatible ONNX graph provided: graph must contain a "
f"{forest_node_name} operator."
)
# construct `tree_args` for `trees` argument
nodes_treeids_attr = utils.find_attribute_in_node(forest_node, "nodes_treeids")
assert nodes_treeids_attr.type == 7 # INTS
nodes_treeids = nodes_treeids_attr.ints
left_attr = utils.find_attribute_in_node(forest_node, "nodes_truenodeids")
assert left_attr.type == 7
left = left_attr.ints
right_attr = utils.find_attribute_in_node(forest_node, "nodes_falsenodeids")
assert right_attr.type == 7
right = right_attr.ints
split_conditions_attr = utils.find_attribute_in_node(forest_node, "nodes_values")
assert split_conditions_attr.type == 6
split_conditions = split_conditions_attr.floats
split_indices_attr = utils.find_attribute_in_node(forest_node, "nodes_featureids")
assert split_indices_attr.type == 7
split_indices = split_indices_attr.ints
tree_args = (nodes_treeids, left, right, split_conditions, split_indices)
n_trees = len(set(nodes_treeids))
# `n_features` arg
model_input = model_proto.graph.input[0]
input_shape = utils.find_input_shape(model_input)
assert len(input_shape) == 2
n_features = input_shape[1].dim_value
n_split_indices = len(set(split_indices))
largest_split_indices = max(split_indices)
if n_split_indices > n_features or largest_split_indices > n_features:
raise ValueError(
f"In the ONNX file, the input shape has {n_features} "
f"features and there are {n_split_indices} distinct split indices . "
f"with the largest index {largest_split_indices}. Validate you "
"set correctly the `initial_types` when converting your model to ONNX."
)
# `base_score` arg
base_score_attr = utils.find_attribute_in_node(
forest_node, "base_values", enforce=False
)
if base_score_attr is None:
base_score = 0.0
else:
assert base_score_attr.type == 6 # FLOATS
base_score = base_score_attr.floats[0]
# `learning_rate` arg
# NOTE: ONNX assumes the leaf weights have already been scaled by the
# learning rate, so we keep our forest's learning_rate scaled fixed at 1.0
learning_rate = 1.0
return forest_node, tree_args, n_trees, n_features, base_score, learning_rate