Source code for bonobo.structs.graphs

import html
import json
from collections import namedtuple
from copy import copy

from graphviz import Digraph
from graphviz import ExecutableNotFound

from bonobo.constants import BEGIN
from bonobo.util import get_name

# Lightweight result of Graph.add_chain(): the graph itself plus the index of
# the chain's first node ("input") and last node ("output").
GraphRange = namedtuple('GraphRange', 'graph input output')

class Graph:
    """
    Represents a directed graph of nodes.

    Nodes are kept in insertion order in ``nodes``; a node's position in that
    list is its index. ``edges`` maps an index (or the ``BEGIN`` sentinel) to
    the set of indexes it feeds into. ``named`` maps chain names to the index
    of the first node of that chain.
    """

    name = ''

    def __init__(self, *chain):
        self.edges = {BEGIN: set()}
        self.named = {}
        self.nodes = []
        self.add_chain(*chain)

    def __iter__(self):
        yield from self.nodes

    def __len__(self):
        """
        Node count.
        """
        return len(self.nodes)

    def __getitem__(self, key):
        return self.nodes[key]

    def outputs_of(self, idx, create=False):
        """
        Get a set of the outputs for a given node index.

        :param idx: node index (or ``BEGIN``).
        :param create: if true, register an empty output set for ``idx`` when
            it does not have one yet.
        :return: the (mutable) set of output indexes of ``idx``.
        :raises KeyError: if ``idx`` has no output set and ``create`` is false.
        """
        if create and idx not in self.edges:
            self.edges[idx] = set()
        return self.edges[idx]

    def add_node(self, c):
        """
        Add a node without connections in this graph and returns its index.
        """
        idx = len(self.nodes)
        self.edges[idx] = set()
        self.nodes.append(c)
        # The structure changed: memoized derived views are now stale.
        self._invalidate_caches()
        return idx

    def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None):
        """
        Add a chain in this graph.

        Each node is appended and connected to the previous one; the first is
        connected from ``_input`` and, if given, the last one is connected to
        ``_output``.

        :param nodes: nodes to append, in order.
        :param _input: index, node or name feeding the first node (``BEGIN`` by default).
        :param _output: index, node or name the last node feeds into, or ``None``.
        :param _name: optional name registered for the first node of the chain.
        :return: ``GraphRange(graph, first_index, last_index)``; both indexes are
            ``None`` when ``nodes`` is empty.
        :raises KeyError: if ``_name`` is already used in this graph.
        :raises ValueError: if ``_input`` or ``_output`` cannot be resolved.
        """
        if not nodes:
            return GraphRange(self, None, None)

        _input = self._resolve_index(_input)
        _output = self._resolve_index(_output)

        # Validate the name before mutating anything, so that a duplicate name
        # does not leave a dangling, unconnected node behind.
        if _name and _name in self.named:
            raise KeyError('Duplicate name {!r} in graph.'.format(_name))

        _first = None
        _last = None
        for node in nodes:
            _last = self.add_node(node)
            if _first is None:
                _first = _last
                if _name:
                    self.named[_name] = _first
            self.outputs_of(_input, create=True).add(_last)
            _input = _last

        if _output is not None:
            self.outputs_of(_input, create=True).add(_output)

        # Edges were added after the nodes; make sure no stale view survives.
        self._invalidate_caches()
        return GraphRange(self, _first, _last)

    def copy(self):
        """
        Return an independent copy of this graph.

        Node objects themselves are shared, but the node list, the name mapping
        and every output set are copied, so extending the copy never alters
        this graph (and vice versa).
        """
        g = Graph()
        g.edges = {idx: set(outputs) for idx, outputs in self.edges.items()}
        g.named = dict(self.named)
        g.nodes = list(self.nodes)
        return g

    def _invalidate_caches(self):
        """Drop memoized views derived from the structure (topological order, graphviz)."""
        for attr in ('_topologcally_sorted_indexes_cache', '_graphviz'):
            vars(self).pop(attr, None)

    @property
    def topologically_sorted_indexes(self):
        """
        Node indexes in topological order, as a tuple (memoized).

        Based on networkx's ``topological_sort()``: an iterative depth-first
        search over ``edges`` that records nodes in post-order, then reverses
        that order. Only ``int`` keys are kept, so ``BEGIN`` is excluded.

        :raises RuntimeError: if the graph contains a cycle.
        """
        try:
            return self._topologcally_sorted_indexes_cache
        except AttributeError:
            pass

        seen = set()
        order = []
        explored = set()

        for i in self.edges:
            if i in explored:
                continue
            fringe = [i]
            while fringe:
                w = fringe[-1]  # depth first search
                if w in explored:  # already looked down this branch
                    fringe.pop()
                    continue
                seen.add(w)  # mark as seen
                # Check successors for cycles and for new nodes
                new_nodes = []
                for n in self.outputs_of(w):
                    if n not in explored:
                        if n in seen:  # CYCLE !!
                            raise RuntimeError("Graph contains a cycle.")
                        new_nodes.append(n)
                if new_nodes:  # Add new_nodes to fringe
                    fringe.extend(new_nodes)
                else:  # No new nodes so w is fully explored
                    explored.add(w)
                    order.append(w)
                    fringe.pop()  # done considering this node

        self._topologcally_sorted_indexes_cache = tuple(
            filter(lambda i: type(i) is int, reversed(order))
        )
        return self._topologcally_sorted_indexes_cache

    @property
    def graphviz(self):
        """
        A ``Digraph`` rendering of this graph (memoized).

        Laid out left to right; ``BEGIN`` is drawn as a point and each node is
        labelled with ``get_name(node)``.
        """
        try:
            return self._graphviz
        except AttributeError:
            pass

        g = Digraph()
        g.attr(rankdir='LR')
        g.node('BEGIN', shape='point')
        for i in self.outputs_of(BEGIN):
            g.edge('BEGIN', str(i))
        for ix in self.topologically_sorted_indexes:
            g.node(str(ix), label=get_name(self[ix]))
            for iy in self.outputs_of(ix):
                g.edge(str(ix), str(iy))

        self._graphviz = g
        return self._graphviz

    def _repr_dot_(self):
        return str(self.graphviz)

    def _repr_html_(self):
        try:
            return '<div>{}</div><pre>{}</pre>'.format(self.graphviz._repr_svg_(), html.escape(repr(self)))
        except (ExecutableNotFound, FileNotFoundError) as exc:
            # Rendering needs the graphviz executables; degrade to an error note.
            return '<strong>{}</strong>: {}'.format(type(exc).__name__, str(exc))

    def _resolve_index(self, mixed):
        """
        Find the index based on various strategies for a node, probably an input or output of chain.
        Supported inputs are indexes, node values or names.

        :return: ``None`` for ``None``; ints and existing ``edges`` keys (such as
            ``BEGIN``) unchanged; otherwise the index of the named chain or of
            the first node equal to ``mixed``.
        :raises ValueError: if nothing matches.
        """
        if mixed is None:
            return None

        if type(mixed) is int or mixed in self.edges:
            return mixed

        if isinstance(mixed, str) and mixed in self.named:
            return self.named[mixed]

        if mixed in self.nodes:
            return self.nodes.index(mixed)

        raise ValueError('Cannot find node matching {!r}.'.format(mixed))
def _get_graphviz_node_id(graph, i):
    """
    Build a DOT fragment declaring node ``i`` of ``graph`` labelled with its name.

    The index becomes the DOT node id and ``get_name(graph[i])`` becomes a
    double-quoted label, e.g. ``{0 [label="extract"]}``.

    ``json.dumps`` quotes and escapes the label. ``ensure_ascii=False`` keeps
    non-ASCII characters as-is: DOT input is UTF-8 and does not decode JSON's
    unicode escape sequences, which would otherwise be rendered literally.
    """
    escaped_index = str(i)
    escaped_name = json.dumps(get_name(graph[i]), ensure_ascii=False)
    return '{{{} [label={}]}}'.format(escaped_index, escaped_name)