# Source code for hatchet.query_matcher

# Copyright 2017-2020 Lawrence Livermore National Security, LLC and other
# Hatchet Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

from itertools import groupby
from numbers import Real
import re
import pandas as pd
from pandas import DataFrame
from pandas.core.indexes.multi import MultiIndex

# Flake8 to ignore this import, it does not recognize that eval("np.nan") needs
# numpy package
import numpy as np  # noqa: F401

from .node import Node, traversal_order


class QueryMatcher:
    """Process and apply queries to GraphFrames."""

    def __init__(self, query=None):
        """Create a new QueryMatcher object.

        Arguments:
            query (list, optional): if provided, convert the contents of the
                high-level API query into an internal representation.
        """
        # Initialize containers for query and memoization cache.
        self.query_pattern = []
        self.search_cache = {}
        # If a high-level API list is provided, process it.
        if query is not None:
            assert isinstance(query, list)

            def _convert_dict_to_filter(attr_filter):
                """Converts high-level API attribute filter to a lambda"""
                # Comparison-operator prefixes accepted in string filters.
                # NOTE(review): the filter string is passed to eval() below;
                # queries must come from trusted callers only.
                compops = ("<", ">", "==", ">=", "<=", "<>", "!=")
                # , # Currently not supported
                # "is", "is not", "in", "not in")

                # This is a dict to work around Python's non-local variable
                # assignment rules.
                #
                # TODO: Replace this with the use of the "nonlocal" keyword
                # once Python 2.7 support is dropped.
                first_no_drop_indices = {"val": True}

                def filter_series(df_row):
                    """Filter a single-index row (a Pandas Series)."""
                    matches = True
                    for k, v in attr_filter.items():
                        if k == "depth":
                            node = df_row.name
                            if isinstance(v, str) and v.lower().startswith(compops):
                                matches = matches and eval(
                                    "{} {}".format(node._depth, v)
                                )
                            elif isinstance(v, Real):
                                matches = matches and (node._depth == v)
                            else:
                                raise InvalidQueryFilter(
                                    "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
                                        k
                                    )
                                )
                            continue
                        if k == "node_id":
                            node = df_row.name
                            if isinstance(v, str) and v.lower().startswith(compops):
                                matches = matches and eval(
                                    "{} {}".format(node._hatchet_nid, v)
                                )
                            elif isinstance(v, Real):
                                matches = matches and (node._hatchet_nid == v)
                            else:
                                raise InvalidQueryFilter(
                                    "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
                                        k
                                    )
                                )
                            continue
                        # A filter on a column that doesn't exist never matches.
                        if k not in df_row.keys():
                            return False
                        if isinstance(df_row[k], str):
                            if not isinstance(v, str):
                                # BUGFIX: the message was previously passed
                                # unformatted together with k as a second arg.
                                raise InvalidQueryFilter(
                                    "Value for attribute {} must be a string.".format(
                                        k
                                    )
                                )
                            # Anchor the user regex so it must match the full value.
                            if re.match(v + r"\Z", df_row[k]) is not None:
                                matches = matches and True
                            else:
                                matches = matches and False
                        elif isinstance(df_row[k], Real):
                            if isinstance(v, str) and v.lower().startswith(compops):
                                # compare nan metric value to numeric query
                                # (e.g. np.nan > 5)
                                if pd.isnull(df_row[k]):
                                    nan_str = "np.nan"
                                    # compare nan metric value to nan query
                                    # (e.g., np.nan == np.nan)
                                    if nan_str in v:
                                        matches = matches and eval(
                                            "pd.isnull({}) == True".format(nan_str)
                                        )
                                    else:
                                        matches = matches and eval(
                                            "{} {}".format(nan_str, v)
                                        )
                                elif np.isinf(df_row[k]):
                                    inf_str = "np.inf"
                                    # compare inf metric value to inf query
                                    # (e.g., np.inf == np.inf)
                                    if inf_str in v:
                                        matches = matches and eval(
                                            "np.isinf({}) == True".format(inf_str)
                                        )
                                    else:
                                        matches = matches and eval(
                                            "{} {}".format(inf_str, v)
                                        )
                                else:
                                    matches = matches and eval(
                                        "{} {}".format(df_row[k], v)
                                    )
                            elif isinstance(v, Real):
                                matches = matches and (df_row[k] == v)
                            else:
                                raise InvalidQueryFilter(
                                    "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
                                        k
                                    )
                                )
                        else:
                            raise InvalidQueryFilter(
                                "Filter must be one of the following:\n * A regex string for a String attribute\n * A string starting with a comparison operator for a Numeric attribute\n * A number for a Numeric attribute\n"
                            )
                    return matches

                def filter_dframe(df_row):
                    """Filter a multi-index row (a Pandas DataFrame slice)."""
                    # Warn (once) that querying without dropping index levels
                    # is slow.
                    if first_no_drop_indices["val"]:
                        print(
                            "==================================================================="
                        )
                        print(
                            "WARNING: You are performing a query without dropping index levels."
                        )
                        print(
                            "         This is a valid operation, but it will significantly"
                        )
                        print(
                            "         increase the time it takes for this operation to complete."
                        )
                        print(
                            "         If you don't want the operation to take so long, call"
                        )
                        print("         GraphFrame.drop_index_levels() before calling")
                        print("         GraphFrame.filter()")
                        print(
                            "===================================================================\n"
                        )
                        first_no_drop_indices["val"] = False
                    matches = True
                    # Extract the graph Node from the first level of the
                    # row's MultiIndex.
                    node = df_row.name.to_frame().index[0][0]
                    for k, v in attr_filter.items():
                        if k == "depth":
                            if isinstance(v, str) and v.lower().startswith(compops):
                                matches = matches and eval(
                                    "{} {}".format(node._depth, v)
                                )
                            elif isinstance(v, Real):
                                matches = matches and (node._depth == v)
                            else:
                                raise InvalidQueryFilter(
                                    "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
                                        k
                                    )
                                )
                            continue
                        if k == "node_id":
                            if isinstance(v, str) and v.lower().startswith(compops):
                                matches = matches and eval(
                                    "{} {}".format(node._hatchet_nid, v)
                                )
                            elif isinstance(v, Real):
                                matches = matches and (node._hatchet_nid == v)
                            else:
                                raise InvalidQueryFilter(
                                    "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
                                        k
                                    )
                                )
                            continue
                        if k not in df_row.columns:
                            return False
                        if df_row[k].apply(type).eq(str).all():
                            if not isinstance(v, str):
                                # BUGFIX: the message was previously passed
                                # unformatted together with k as a second arg.
                                raise InvalidQueryFilter(
                                    "Value for attribute {} must be a string.".format(
                                        k
                                    )
                                )
                            if (
                                df_row[k]
                                .apply(lambda x: re.match(v + r"\Z", x) is not None)
                                .any()
                            ):
                                matches = matches and True
                            else:
                                matches = matches and False
                        # NOTE(review): type(x) can never equal the abstract
                        # class Real, so this branch appears unreachable for
                        # int/float columns — verify intended behavior.
                        elif df_row[k].apply(type).eq(Real).all():
                            if isinstance(v, str) and v.lower().startswith(compops):
                                matches = (
                                    matches
                                    and df_row[k]
                                    .apply(lambda x: eval("{} {}".format(x, v)))
                                    .any()
                                )
                            elif isinstance(v, Real):
                                matches = (
                                    matches and df_row[k].apply(lambda x: x == v).any()
                                )
                            else:
                                raise InvalidQueryFilter(
                                    "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
                                        k
                                    )
                                )
                        else:
                            raise InvalidQueryFilter(
                                "Filter must be one of the following:\n * A regex string for a String attribute\n * A string starting with a comparison operator for a Numeric attribute\n * A number for a Numeric attribute\n"
                            )
                    return matches

                def filter_choice(df_row):
                    # Dispatch based on whether the row is a DataFrame
                    # (MultiIndex case) or a Series (single-index case).
                    if isinstance(df_row, DataFrame):
                        return filter_dframe(df_row)
                    return filter_series(df_row)

                # An empty filter dict matches everything.
                return filter_choice if attr_filter != {} else lambda row: True

            # Translate each high-level query element into an internal
            # (wildcard, filter) entry.
            for elem in query:
                if isinstance(elem, dict):
                    self._add_node(".", _convert_dict_to_filter(elem))
                elif isinstance(elem, str) or isinstance(elem, int):
                    self._add_node(elem)
                elif isinstance(elem, tuple):
                    assert isinstance(elem[1], dict)
                    if isinstance(elem[0], str) or isinstance(elem[0], int):
                        self._add_node(elem[0], _convert_dict_to_filter(elem[1]))
                    else:
                        raise InvalidQueryPath(
                            "The first value of a tuple entry in a path must be either a string or integer."
                        )
                else:
                    raise InvalidQueryPath(
                        "A query path must be a list containing String, Integer, Dict, or Tuple elements"
                    )
[docs] def match(self, wildcard_spec=".", filter_func=lambda row: True): """Start a query with a root node described by the arguments. Arguments: wildcard_spec (str, optional, ".", "*", or "+"): the wildcard status of the node (follows standard Regex syntax) filter_func (callable, optional): a callable accepting only a row from a Pandas DataFrame that is used to filter this node in the query Returns: (QueryMatcher): The instance of the class that called this function (enables fluent design). """ if len(self.query_pattern) != 0: self.query_pattern = [] self._add_node(wildcard_spec, filter_func) return self
[docs] def rel(self, wildcard_spec=".", filter_func=lambda row: True): """Add another edge and node to the query. Arguments: wildcard_spec (str, optional, ".", "*", or "+"): the wildcard status of the node (follows standard Regex syntax) filter_func (callable, optional): a callable accepting only a row from a Pandas DataFrame that is used to filter this node in the query Returns: (QueryMatcher): The instance of the class that called this function (enables fluent design). """ self._add_node(wildcard_spec, filter_func) return self
[docs] def apply(self, gf): """Apply the query to a GraphFrame. Arguments: gf (GraphFrame): the GraphFrame on which to apply the query. Returns: (list): A list of lists representing the set of paths that match this query. """ self.search_cache = {} matches = [] visited = set() for root in sorted(gf.graph.roots, key=traversal_order): self._apply_impl(gf, root, visited, matches) assert len(visited) == len(gf.graph) return matches
    def _add_node(self, wildcard_spec=".", filter_func=lambda row: True):
        """Add a node to the query.

        Arguments:
            wildcard_spec (str, optional, ".", "*", or "+"): the wildcard
                status of the node (follows standard Regex syntax)
            filter_func (callable, optional): a callable accepting only a row
                from a Pandas DataFrame that is used to filter this node in
                the query
        """
        assert isinstance(wildcard_spec, int) or isinstance(wildcard_spec, str)
        assert callable(filter_func)
        if isinstance(wildcard_spec, int):
            # An integer spec means "exactly N nodes": expand it into N
            # consecutive "." (match-one) pattern entries.
            for i in range(wildcard_spec):
                self.query_pattern.append((".", filter_func))
        else:
            assert wildcard_spec == "." or wildcard_spec == "*" or wildcard_spec == "+"
            self.query_pattern.append((wildcard_spec, filter_func))

    def _cache_node(self, gf, node):
        """Cache (Memoize) the parts of the query that the node matches.

        Arguments:
            gf (GraphFrame): the GraphFrame containing the node to be cached.
            node (Node): the Node to be cached.
        """
        assert isinstance(node, Node)
        matches = []
        # Applies each filtering function to the node to cache which
        # query nodes the current node matches.
        for i, node_query in enumerate(self.query_pattern):
            _, filter_func = node_query
            row = None
            if isinstance(gf.dataframe.index, MultiIndex):
                # MultiIndex case: give the filter the node's whole
                # cross-section (all index levels) as a DataFrame.
                row = pd.concat([gf.dataframe.loc[node]], keys=[node], names=["node"])
            else:
                # Single-index case: the filter receives a Series.
                row = gf.dataframe.loc[node]
            if filter_func(row):
                matches.append(i)
        # Keyed by the node's unique hatchet id.
        self.search_cache[node._hatchet_nid] = matches

    def _match_0_or_more(self, gf, node, wcard_idx):
        """Process a "*" wildcard in the query on a subgraph.

        Arguments:
            gf (GraphFrame): the GraphFrame being queried.
            node (Node): the node being queried against the "*" wildcard.
            wcard_idx (int): the index associated with the "*" wildcard query.

        Returns:
            (list): a list of lists representing the paths rooted at "node"
                that match the "*" wildcard and/or the next query node. Will
                return None if there is no match for the "*" wildcard or the
                next query node.
        """
        # Cache the node if it's not already cached
        if node._hatchet_nid not in self.search_cache:
            self._cache_node(gf, node)
        # If the node matches with the next non-wildcard query node,
        # end the recursion and return the node.
        if wcard_idx + 1 in self.search_cache[node._hatchet_nid]:
            return [[]]
        # If the node matches the "*" wildcard query, recursively
        # apply this function to the current node's children. Then,
        # collect their returned matches, and prepend the current node.
        elif wcard_idx in self.search_cache[node._hatchet_nid]:
            matches = []
            if len(node.children) == 0:
                # A leaf only matches if the "*" wildcard is the last
                # entry of the query pattern.
                if wcard_idx == len(self.query_pattern) - 1:
                    return [[node]]
                return None
            for child in sorted(node.children, key=traversal_order):
                sub_match = self._match_0_or_more(gf, child, wcard_idx)
                if sub_match is not None:
                    matches.extend(sub_match)
            if len(matches) == 0:
                return None
            # Deduplicate the collected sub-paths before prepending
            # the current node.
            tmp = set(tuple(m) for m in matches)
            matches = [list(t) for t in tmp]
            return [[node] + m for m in matches]
        # If the current node doesn't match the current "*" wildcard or
        # the next non-wildcard query node, return None.
        else:
            if wcard_idx == len(self.query_pattern) - 1:
                return [[]]
            return None

    def _match_1_or_more(self, gf, node, wcard_idx):
        """Process a "+" wildcard in the query on a subgraph.

        Arguments:
            gf (GraphFrame): the GraphFrame being queried.
            node (Node): the node being queried against the "+" wildcard.
            wcard_idx (int): the index associated with the "+" wildcard query.

        Returns:
            (list): a list of lists representing the paths rooted at "node"
                that match the "+" wildcard and/or the next query node. Will
                return None if there is no match for the "+" wildcard or the
                next query node.
        """
        # Cache the node if it's not already cached
        if node._hatchet_nid not in self.search_cache:
            self._cache_node(gf, node)
        # If the current node doesn't match the "+" wildcard, return None.
        if wcard_idx not in self.search_cache[node._hatchet_nid]:
            return None
        # Since a query can't end on a wildcard, return None if the
        # current node has no children.
        if len(node.children) == 0:
            return None
        # Use _match_0_or_more to collect all additional wildcard matches.
        # ("+" is "one concrete match, then zero-or-more".)
        matches = []
        for child in sorted(node.children, key=traversal_order):
            sub_match = self._match_0_or_more(gf, child, wcard_idx)
            if sub_match is not None:
                matches.extend(sub_match)
        # Since _match_0_or_more will capture the query node that follows
        # the wildcard, if no paths were retrieved from that function,
        # the pattern does not continue after the "+" wildcard. Thus,
        # since a pattern cannot end on a wildcard, the function
        # returns None.
        if len(matches) == 0:
            return None
        return [[node] + m for m in matches]

    def _match_1(self, gf, node, idx):
        """Process a "." wildcard in the query on a subgraph.

        Arguments:
            gf (GraphFrame): the GraphFrame being queried.
            node (Node): the node being queried against the "." wildcard.
            idx (int): the index associated with the "." wildcard query.

        Returns:
            (list): A list of lists representing the children of "node" that
                match the "." wildcard being considered. Will return None if
                there are no matches for the "." wildcard.
        """
        if node._hatchet_nid not in self.search_cache:
            self._cache_node(gf, node)
        matches = []
        for child in sorted(node.children, key=traversal_order):
            # Cache the node if it's not already cached
            if child._hatchet_nid not in self.search_cache:
                self._cache_node(gf, child)
            if idx in self.search_cache[child._hatchet_nid]:
                matches.append([child])
        # To be consistent with the other matching functions, return
        # None instead of an empty list.
        if len(matches) == 0:
            return None
        return matches

    def _match_pattern(self, gf, pattern_root, match_idx):
        """Try to match the query pattern starting at the provided root node.

        Arguments:
            gf (GraphFrame): the GraphFrame being queried.
            pattern_root (Node): the root node of the subgraph that is being
                queried.
            match_idx (int): the index in the query pattern where matching
                starts.

        Returns:
            (list): A list of lists representing the paths rooted at
                "pattern_root" that match the query.
        """
        assert isinstance(pattern_root, Node)
        # Starting query node
        pattern_idx = match_idx + 1
        if (
            self.query_pattern[match_idx][0] == "*"
            or self.query_pattern[match_idx][0] == "+"
        ):
            pattern_idx = 0
        # Starting matching pattern
        matches = [[pattern_root]]
        while pattern_idx < len(self.query_pattern):
            # Get the wildcard type
            wcard, _ = self.query_pattern[pattern_idx]
            new_matches = []
            # Consider each existing match individually so that more
            # nodes can be added to them.
            for m in matches:
                sub_match = []
                # Get the portion of the subgraph that matches the next
                # part of the query.
                if wcard == ".":
                    s = self._match_1(gf, m[-1], pattern_idx)
                    if s is None:
                        sub_match.append(s)
                    else:
                        sub_match.extend(s)
                elif wcard == "*":
                    if len(m[-1].children) == 0:
                        # "*" can legally match zero nodes at a leaf.
                        sub_match.append([])
                    else:
                        for child in sorted(m[-1].children, key=traversal_order):
                            s = self._match_0_or_more(gf, child, pattern_idx)
                            if s is None:
                                sub_match.append(s)
                            else:
                                sub_match.extend(s)
                elif wcard == "+":
                    if len(m[-1].children) == 0:
                        # "+" requires at least one node; a leaf cannot
                        # extend the match.
                        sub_match.append(None)
                    else:
                        for child in sorted(m[-1].children, key=traversal_order):
                            s = self._match_1_or_more(gf, child, pattern_idx)
                            if s is None:
                                sub_match.append(s)
                            else:
                                sub_match.extend(s)
                else:
                    raise InvalidQueryFilter(
                        'Query wildcards must be one of ".", "*", or "+"'
                    )
                # Merge the next part of the match path with the
                # existing part.
                for s in sub_match:
                    if s is not None:
                        new_matches.append(m + s)
                # Drop consecutive duplicate paths.
                new_matches = [uniq_match for uniq_match, _ in groupby(new_matches)]
            # Overwrite the old matches with the updated matches
            matches = new_matches
            # If all the existing partial matches were not able to be
            # expanded into full matches, return None.
            if len(matches) == 0:
                return None
            # Update the query node
            pattern_idx += 1
        return matches

    def _apply_impl(self, gf, node, visited, matches):
        """Traverse the subgraph with the specified root, and collect all
        paths that match the query.

        Arguments:
            gf (GraphFrame): the GraphFrame being queried.
            node (Node): the root node of the subgraph that is being queried.
            visited (set): a set that keeps track of what nodes have been
                visited in the traversal to minimize the amount of work that
                is repeated.
            matches (list): the list in which the final set of matches are
                stored.
        """
        # If the node has already been visited (or is None for some
        # reason), skip it.
        if node is None or node._hatchet_nid in visited:
            return
        # Cache the node if it's not already cached
        if node._hatchet_nid not in self.search_cache:
            self._cache_node(gf, node)
        # If the node matches the starting/root node of the query,
        # try to get all query matches in the subgraph rooted at
        # this node.
        if self.query_pattern[0][0] == "*":
            # A leading "*" may match zero nodes, so also try starting
            # the match at the query node that follows it.
            if 1 in self.search_cache[node._hatchet_nid]:
                sub_match = self._match_pattern(gf, node, 1)
                if sub_match is not None:
                    matches.extend(sub_match)
        if 0 in self.search_cache[node._hatchet_nid]:
            sub_match = self._match_pattern(gf, node, 0)
            if sub_match is not None:
                matches.extend(sub_match)
        # Note that the node is now visited.
        visited.add(node._hatchet_nid)
        # Continue the Depth First Search.
        for child in sorted(node.children, key=traversal_order):
            self._apply_impl(gf, child, visited, matches)
class InvalidQueryPath(Exception):
    """Raised when a query does not have the correct syntax"""
class InvalidQueryFilter(Exception):
    """Raised when a query filter does not have a valid syntax"""