# Source code for hatchet.node

# Copyright 2017-2020 Lawrence Livermore National Security, LLC and other
# Hatchet Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

from functools import total_ordering

from .frame import Frame


def traversal_order(node):
    """Deterministic key function for sorting nodes in traversals.

    Arguments:
        node: object exposing a ``frame`` attribute (typically a ``Node``).

    Returns:
        (tuple): ``(node.frame, id(node))``. The frame is the primary sort
        key; the object's ``id`` breaks ties between nodes whose frames
        compare equal, so the node objects themselves never need to be
        compared.
    """
    return (node.frame, id(node))
@total_ordering
class Node:
    """A node in the graph. The node only stores its frame.

    Attributes set by the constructor:
        frame: the frame object this node wraps.
        parents (list): nodes registered through ``add_parent``.
        children (list): nodes registered through ``add_child``.
        _hatchet_nid (int): numeric id used for hashing, equality and
            ordering (default: -1).
        _depth (int): depth value supplied by the caller (default: -1).
    """

    def __init__(self, frame_obj, parent=None, hnid=-1, depth=-1):
        """Create a node wrapping ``frame_obj``.

        Arguments:
            frame_obj: frame stored on the node as ``self.frame``
            parent (Node, optional): if given, appended to ``self.parents``;
                the new node is NOT added to the parent's ``children``
            hnid (int): value stored as ``self._hatchet_nid``
            depth (int): value stored as ``self._depth``
        """
        self.frame = frame_obj
        self._depth = depth
        self._hatchet_nid = hnid

        self.parents = []
        if parent is not None:
            self.add_parent(parent)
        self.children = []

    def add_parent(self, node):
        """Adds a parent to this node's list of parents."""
        assert isinstance(node, Node)
        self.parents.append(node)

    def add_child(self, node):
        """Adds a child to this node's list of children."""
        assert isinstance(node, Node)
        self.children.append(node)

    def paths(self, attrs=None):
        """List of tuples, one for each path from this node to any root.

        Arguments:
            attrs (str or list, optional): attribute(s) to extract from Frames

        Paths are tuples of Frame objects, or, if attrs is provided, they
        are paths containing the requested attributes. Each tuple is ordered
        from the root down to this node.
        """
        if attrs is None:
            node_value = (self.frame,)
        else:
            node_value = (self.frame.values(attrs),)

        # A node without parents is a root: its only path is itself.
        if not self.parents:
            return [node_value]

        # Otherwise extend every path of every parent with this node.
        paths = []
        for parent in self.parents:
            paths.extend(path + node_value for path in parent.paths(attrs))
        return paths

    def path(self, attrs=None):
        """Path to this node from root. Raises if there are multiple paths.

        Arguments:
            attrs (str or list, optional): attribute(s) to extract from Frames

        This is useful for trees (where each node only has one path), as
        it just gets the only element from ``self.paths``. This will
        fail with a MultiplePathError if there is more than one path to
        this node.
        """
        paths = self.paths(attrs)
        if len(paths) > 1:
            # The %s placeholder plus the 1-tuple wrapper make sure the whole
            # list of paths is rendered into the error message.
            raise MultiplePathError("Node has more than one path: %s" % (paths,))
        return paths[0]

    def dag_equal(self, other, vs=None, vo=None):
        """Check if DAG rooted at self has the same structure as that rooted
        at other.

        Arguments:
            other (Node): root of the DAG to compare against
            vs (set, optional): ``_hatchet_nid`` values already visited on
                the ``self`` side; created on the outermost call
            vo (set, optional): ``_hatchet_nid`` values already visited on
                the ``other`` side; created on the outermost call

        Returns:
            (bool): True if children match pairwise by frame (after sorting
            by frame) all the way down, False otherwise. The frames of the
            two root nodes themselves are not compared.
        """
        if vs is None:
            vs = set()
        if vo is None:
            vo = set()

        vs.add(self._hatchet_nid)
        vo.add(other._hatchet_nid)

        # if number of children do not match, then nodes are not equal
        if len(self.children) != len(other.children):
            return False

        # sort children of each node by its frame
        ssorted = sorted(self.children, key=lambda x: x.frame)
        osorted = sorted(other.children, key=lambda x: x.frame)

        for self_child, other_child in zip(ssorted, osorted):
            # if frames do not match, then nodes are not equal
            if self_child.frame != other_child.frame:
                return False

            visited_s = self_child._hatchet_nid in vs
            visited_o = other_child._hatchet_nid in vo

            # check for duplicate nodes: a child reached a second time on one
            # side must also have been reached before on the other side
            if visited_s != visited_o:
                return False

            # skip visited nodes
            if visited_s or visited_o:
                continue

            # recursive check for node equality
            if not self_child.dag_equal(other_child, vs, vo):
                return False

        return True

    def traverse(self, order="pre", attrs=None, visited=None):
        """Traverse the tree depth-first and yield each node.

        Arguments:
            order (str):  "pre" or "post" for preorder or postorder (default: pre)
            attrs (list or str, optional): if provided, extract these
                fields from nodes while traversing and yield them
            visited (dict, optional): dictionary in which each visited
                node's in-degree will be stored, keyed by ``id(node)``

        Raises:
            ValueError: if ``order`` is not "pre" or "post". Because this is
                a generator, the error surfaces when iteration starts.
        """
        if order not in ("pre", "post"):
            raise ValueError("order must be one of 'pre' or 'post'")

        if visited is None:
            visited = {}

        key = id(self)
        if key in visited:
            # count the number of times we reached this node, but yield it
            # (and descend into it) only on the first visit
            visited[key] += 1
            return
        visited[key] = 1

        def value(node):
            return node if attrs is None else node.frame.values(attrs)

        if order == "pre":
            yield value(self)

        for child in sorted(self.children, key=traversal_order):
            for item in child.traverse(order=order, attrs=attrs, visited=visited):
                yield item

        if order == "post":
            yield value(self)

    def __hash__(self):
        return self._hatchet_nid

    def __eq__(self, other):
        # Defer to Python's fallback for operands that are not nodes, so
        # comparisons such as ``node == None`` evaluate to False.
        if not isinstance(other, Node):
            return NotImplemented
        return self._hatchet_nid == other._hatchet_nid

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self._hatchet_nid < other._hatchet_nid

    def __gt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self._hatchet_nid > other._hatchet_nid

    def __str__(self):
        """Returns a string representation of the node."""
        return str(self.frame)

    def copy(self):
        """Copy this node without preserving parents or children.

        Only the frame is copied; the new node is built with the default
        ``hnid`` and ``depth`` values.
        """
        return Node(frame_obj=self.frame.copy())

    @classmethod
    def from_lists(cls, lists):
        r"""Construct a hierarchy of nodes from recursive lists.

        For example, this will construct a simple tree:

        .. code-block:: python

            Node.from_lists(
                ["a",
                    ["b", "d", "e"],
                    ["c", "f", "g"],
                ]
            )

        .. code-block:: console

                 a
                / \
               b   c
             / |   | \
            d  e   f  g

        And this will construct a simple diamond DAG:

        .. code-block:: python

            d = Node(Frame(name="d"))
            Node.from_lists(
                ["a",
                    ["b", d],
                    ["c", d]
                ]
            )

        .. code-block:: console

              a
             / \
            b   c
             \ /
              d

        In the above examples, the 'a' represents a Node with its
        `frame == Frame(name="a")`.

        Raises:
            ValueError: if an element is not a str, list/tuple or Node, if a
                list is empty, or if the first element of a list is neither
                a str nor a Node.
        """

        def _from_lists(lists, parent):
            if isinstance(lists, (tuple, list)):
                # The first element names the node; the rest are children.
                if not lists:
                    raise ValueError("Cannot construct a Node from an empty list")

                head = lists[0]
                if isinstance(head, Node):
                    node = head
                elif isinstance(head, str):
                    node = Node(Frame(name=head))
                else:
                    raise ValueError(
                        "First element of a list must be str or Node: %s" % (head,)
                    )

                for val in lists[1:]:
                    _from_lists(val, node)

            elif isinstance(lists, str):
                node = Node(Frame(name=lists))

            elif isinstance(lists, Node):
                node = lists

            else:
                raise ValueError(
                    "Argument must be str, list, or Node: %s" % (lists,)
                )

            # Link both directions so parents and children stay consistent.
            if parent is not None:
                node.add_parent(parent)
                parent.add_child(node)

            return node

        return _from_lists(lists, None)

    def __repr__(self):
        return "Node({%s})" % ", ".join(
            "%s: %s" % (repr(k), repr(v)) for k, v in sorted(self.frame.attrs.items())
        )
class MultiplePathError(Exception):
    """Raised when a node is asked for a single path but has multiple.

    ``Node.path`` raises this when ``Node.paths`` returns more than one
    path for the node.
    """