# Source code for noob.scheduler

import contextlib
import logging
from collections import defaultdict, deque
from collections.abc import MutableSequence
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import UTC, datetime
from functools import cached_property
from itertools import count
from typing import Self
from uuid import uuid4

from noob.event import Event, MetaEvent, MetaEventType, MetaSignal
from noob.exceptions import AlreadyDoneError, EpochCompletedError, EpochExistsError, NotAddedError
from noob.logging import init_logger
from noob.node import Edge, NodeSpecification
from noob.toposort import GraphItem, NodeSignal, TopoSorter
from noob.types import Epoch, NodeID, SignalName

# Node ids that the scheduler treats specially: they are filtered out when the
# source nodes are computed, and ``Scheduler.get_ready`` always reports them as
# ready without consulting a ``NodeSpecification.enabled`` flag.
_VIRTUAL_NODES = ("input", "assets")
"""
Virtual nodes that don't actually exist as nodes,
but can be depended on 
(and can be present or absent, and so shouldn't be marked as trivially done)
"""


[docs] @dataclass() class Scheduler: nodes: dict[str, NodeSpecification] edges: list[Edge] source_nodes: list[NodeID] = field(default_factory=list) _logger: logging.Logger = field(default_factory=lambda: init_logger("noob.scheduler")) _clock: count = field(default_factory=count) _epochs: dict[Epoch, TopoSorter] = field(default_factory=dict) _subepochs: dict[Epoch, set[Epoch]] = field(default_factory=lambda: defaultdict(set)) _epoch_log: deque[int] = field(default_factory=lambda: deque(maxlen=100)) _subgraphs: dict[NodeID, tuple[dict[str, NodeSpecification], list[Edge]]] = field( default_factory=dict ) _frozen_sorters: dict[tuple[NodeID, ...], TopoSorter] = field(default_factory=dict) def __post_init__(self): self._get_sources()
[docs] @classmethod def from_specification(cls, nodes: dict[str, NodeSpecification], edges: list[Edge]) -> Self: """ Create an instance of a Scheduler from :class:`.NodeSpecification` and :class:`.Edge` """ return cls(nodes=nodes, edges=edges)
def _get_sources(self) -> Self: """ Get the IDs of the nodes that do not depend on other nodes. * `input` nodes are special implicit source nodes. Other nodes * CAN depend on it and still be a source node. """ if not self.source_nodes: graph = self._init_graph() self.source_nodes = [ id_ for id_ in graph.ready_nodes if id_ not in _VIRTUAL_NODES and not isinstance(id_, NodeSignal) ] return self @property def subepochs(self) -> dict[Epoch, set[Epoch]]: return self._subepochs
[docs] @cached_property def graph_signals(self) -> set[tuple[NodeID, SignalName]]: """ The set of (node id, signal) tuples that are depended on in the graph. Nodes can have many more signals than we actually care about for structuring the graph, this set is only the ones that we care about. """ return {(e.source_node, e.source_signal) for e in self.edges}
[docs] def add_epoch(self, epoch: int | Epoch | None = None) -> Epoch: """ Add another epoch with a prepared graph to the scheduler. """ if epoch is not None: if isinstance(epoch, int): this_epoch = Epoch(epoch) elif isinstance(epoch, Epoch): this_epoch = epoch else: raise TypeError("Can only create an epoch from an epoch or integer") # ensure that the next iteration of the clock will return the next number # if we create epochs out of order self._clock = count( max([this_epoch[0].epoch, *[ep[0].epoch for ep in self._epochs], *self._epoch_log]) + 1 ) else: this_epoch = Epoch(next(self._clock)) if this_epoch in self._epochs: raise EpochExistsError(f"Epoch {this_epoch} is already scheduled") elif this_epoch in self._epoch_log: raise EpochCompletedError(f"Epoch {this_epoch} has already been completed!") graph = self._init_graph(epoch=this_epoch) self._epochs[this_epoch] = graph return this_epoch
[docs] def add_subepoch(self, epoch: Epoch) -> Epoch: """ Add subepoch! Creates a topo sorter with all the nodes downstream of the node that created the epoch. """ if epoch.parent is None: raise ValueError(f"Cannot create a subepoch for root epoch {epoch}") parent_epoch = self[epoch.parent] sorter = self._init_graph(epoch) # mark any nodes that are completed in the parent as completed in the subepoch # EXCEPT don't expire the node that induced the subepoch or its signals - # we expect that the subepoch is typically created during an `update` call # where we'll be handling done or expiredness of the signals separately. parent_deps = set(sorter.node_info) exclude_current = sorter.signals[epoch[-1].node_id] | {epoch[-1].node_id} for parent_dep in parent_deps: if parent_dep in parent_epoch.ran_nodes: sorter.done(parent_dep) elif parent_dep in parent_epoch.done_nodes and parent_dep not in exclude_current: sorter.mark_expired(parent_dep, unlock_optionals=False) elif parent_dep in parent_epoch.out_nodes: sorter.mark_out(parent_dep) self._epochs[epoch] = sorter for parent in epoch.parents: self._subepochs[parent].add(epoch) # a node inducing subepochs expires the node in the (immediate) parent epoch if epoch[-1].node_id not in parent_epoch.done_nodes: self.expire(epoch.parent, epoch[-1].node_id, with_signals=False, unlock_optionals=False) return epoch
[docs] def is_active(self, epoch: Epoch | None = None) -> bool: """ Graph remains active while it holds at least one epoch that is active. """ if epoch is not None: if epoch not in self._epochs: # if an epoch has been completed and had its graph cleared, it's no longer active # if an epoch has not been started, it is also not active. return False return any(self._epochs[e].is_active() for e in {*self._subepochs[epoch], epoch}) else: return any(graph.is_active() for graph in self._epochs.values())
[docs] def get_ready( self, epoch: Epoch | None = None, node_id: NodeID | None = None ) -> list[MetaEvent]: """ Output the set of nodes that are ready across different epochs. Args: epoch (Epoch | None): if an Epoch, get ready events for that epoch, if ``None`` , get ready events for all epochs. node_id (str | None): If present, only get ready events for a single node """ if epoch is not None: graphs = [ (ep, self._epochs[ep]) for ep in {*self._subepochs.get(epoch, set()), epoch} if ep in self._epochs ] else: graphs = list(self._epochs.items()) graphs = sorted( graphs, key=lambda g: (tuple(e.node_id for e in g[0]), tuple(e.epoch for e in g[0])) ) ready_nodes = [] for epoch, graph in graphs: for node in graph.get_ready(node_id): if isinstance(node, NodeSignal): self._logger.warning( "Scheduler attempted to return signal tuple %s in %s - " "something is wrong with how the graph is instantiated or run, " "or a node is emitting incorrect events manually, " "all signals should be marked done/expired by events passed in `update`. " "Ignoring - nodes downstream of this signal will not run.", node, epoch, ) graph.mark_expired(node) continue elif node in _VIRTUAL_NODES or (node not in self.nodes or self.nodes[node].enabled): ready_nodes.append( MetaEvent( id=uuid4().int, timestamp=datetime.now(), node_id="meta", signal=MetaEventType.NodeReady, epoch=epoch, value=node, ) ) return ready_nodes
[docs] def node_is_ready(self, node: NodeID, epoch: Epoch | None = None) -> bool: """ Check if a single node is ready in a single or any epoch Args: node (NodeID): the node to check epoch (int | None): the epoch to check, if ``None`` , any epoch """ # slight duplication of the above because we don't want to *get* the ready nodes, # which marks them as "out" in the TopoSorter # if we've already run this, the node is ready - don't create another epoch if epoch in self._epoch_log: return True graphs = ( self._epochs.items() if epoch is None else [(ep, self[ep]) for ep in [epoch, *self._subepochs[epoch]]] ) is_ready = any(node in graph.ready_nodes for epoch, graph in graphs) return is_ready
[docs] def node_is_done(self, node: NodeID, epoch: Epoch) -> bool: """Node is expired or done in specified epoch""" if epoch in self._epoch_log: return True if self._subepochs[epoch]: return all(node in self._epochs[e].done_nodes for e in self._subepochs[epoch] | {epoch}) else: return node in self._epochs[epoch].done_nodes
def __getitem__(self, epoch: Epoch | int) -> TopoSorter: if epoch == -1: if len(self._epochs) == 1: return next(iter(self._epochs.values())) else: max_epoch = max(*[e[0].epoch for e in self._epochs]) return self._epochs[Epoch(max_epoch)] elif isinstance(epoch, int): epoch = Epoch(epoch) if epoch not in self._epochs: if len(epoch) == 1: self.add_epoch(epoch) else: self.add_subepoch(epoch) return self._epochs[epoch]
[docs] def sources_finished(self, epoch: Epoch | None = None) -> bool: """ Check the source nodes of the given epoch have been processed. If epoch is None, check the source nodes of the latest epoch. """ if epoch is None and len(self._epochs) == 0: return True graph = self[-1] if epoch is None else self._epochs[epoch] return all(src in graph.done_nodes for src in self.source_nodes)
[docs] def update( self, events: MutableSequence[Event | MetaEvent] | MutableSequence[Event] ) -> MutableSequence[Event] | MutableSequence[Event | MetaEvent]: """ When a set of events are received, update the graphs within the scheduler. Currently only has :meth:`TopoSorter.done` implemented. """ if not events: return events end_events: MutableSequence[MetaEvent] = [] nodes_done = set() # process subepochs first so they're created when we handle parent epochs events = sorted(events, key=lambda ee: len(ee["epoch"]), reverse=True) for e in events: if e["node_id"] == "meta": continue elif (node_done := (e["epoch"], e["node_id"])) not in nodes_done: nodes_done.add(node_done) # FIXME: This exception suppression is a *bit* broad - fix underlying issue # The zmq runner has an incomplete graph, and so sometimes we don't have # all the nodes in the graph when we go to mark the node done. with contextlib.suppress(AlreadyDoneError, NotAddedError): epoch_ended = self.done(e["epoch"], e["node_id"], with_signals=False) if epoch_ended: end_events.append(epoch_ended) continue if (e["node_id"], e["signal"]) not in self.graph_signals: continue if e["value"] == MetaSignal.NoEvent: epoch_ended = self.expire( epoch=e["epoch"], node_id=e["node_id"], signal=e["signal"] ) else: epoch_ended = self.done(epoch=e["epoch"], node_id=e["node_id"], signal=e["signal"]) if epoch_ended: end_events.append(epoch_ended) ret_events = [*events, *end_events] return ret_events
[docs] def done( self, epoch: Epoch, node_id: str, signal: SignalName | None = None, with_signals: bool = True, ) -> MetaEvent | None: """ Mark a node in a given epoch as done. Args: with_signals (bool): When marking this node as done, also mark all its signals as done. """ if epoch[0].epoch in self._epoch_log: self._logger.debug( "Marking node %s as done in epoch %s, " "but epoch was already completed. ignoring", node_id, epoch, ) return None to_mark = NodeSignal(node_id, signal) if signal is not None else node_id try: self[epoch].done(to_mark) except AlreadyDoneError as e: if not self._subepochs[epoch]: raise AlreadyDoneError(f"Node {node_id} already done in {epoch}") from e self._done_subepochs(epoch, node_id, signal) for parent in epoch.parents: self[parent].mark_expired(to_mark, unlock_optionals=False) if signal is None and with_signals: self[epoch].done(*self[epoch].signals[node_id].difference(self[epoch].done_nodes)) if not self.is_active(epoch): return self.end_epoch(epoch) return None
[docs] def expire( self, epoch: Epoch, node_id: str, signal: SignalName | None = None, with_signals: bool = True, unlock_optionals: bool = True, ) -> MetaEvent | None: """ Mark a node as having been completed without making its dependent nodes ready. i.e. when the node emitted ``NoEvent`` """ to_mark = NodeSignal(node_id, signal) if signal is not None else node_id self[epoch].mark_expired(to_mark, unlock_optionals=unlock_optionals) # if any immediate successors are already marked as "ready," we also want to cancel them. if info := self[epoch].node_info.get(to_mark): for successor in info.successors: self[epoch].ready_nodes.discard(successor) if signal is None and with_signals: for graph_node in self[epoch].signals[node_id]: self.expire( epoch, node_id=node_id, signal=graph_node[1], unlock_optionals=unlock_optionals ) if not self.is_active(epoch): return self.end_epoch(epoch) return None
[docs] def epoch_completed(self, epoch: Epoch) -> bool: """ Check if the epoch has been completed. """ previously_completed = ( len(self._epoch_log) > 0 and epoch not in self._epochs and (epoch in self._epoch_log or epoch < min(self._epoch_log)) ) active_completed = epoch in self._epochs and not any( self._epochs[ep].is_active() for ep in [epoch, *self._subepochs[epoch]] ) return previously_completed or active_completed
[docs] def end_epoch(self, epoch: Epoch | int | None = None) -> MetaEvent | None: if epoch is None or epoch == -1: if len(self._epochs) == 0: return None ep = list(self._epochs)[-1] elif isinstance(epoch, int): ep = Epoch(epoch) elif isinstance(epoch, Epoch): ep = epoch else: raise TypeError("Can only end an epoch with an integer or Epoch") self._logger.debug("Ending epoch %s", ep) if len(ep) == 1: self._epoch_log.append(ep[0].epoch) for subep in {ep, *self._subepochs[ep]}: with contextlib.suppress(KeyError): del self._epochs[subep] return MetaEvent( id=uuid4().int, timestamp=datetime.now(UTC), node_id="meta", signal=MetaEventType.EpochEnded, epoch=ep, value=ep, )
[docs] def enable_node(self, node_id: str) -> None: """ Enable edges attached to the node and the NodeSpecification enable switches to True """ self.nodes[node_id].enabled = True self._frozen_sorters = {}
[docs] def disable_node(self, node_id: str) -> None: """ Disable edges attached to the node and the NodeSpecification enable switches to False """ self.nodes[node_id].enabled = False self._frozen_sorters = {} for graph in self._epochs.values(): graph.mark_expired(node_id)
[docs] def clear(self) -> None: """ Remove epoch records, restarting the scheduler """ self._epochs = {} self._epoch_log = deque(maxlen=100)
def _init_graph(self, epoch: Epoch | None = None) -> TopoSorter: """ Produce a :class:`.TopoSorter` based on the graph induced by a set of :class:`.Node` and a set of :class:`.Edge` that yields node ids. """ frozen_key = ("tube",) if epoch is None else tuple(e.node_id for e in epoch) if frozen_key not in self._frozen_sorters: if epoch and epoch.parent: nodes, edges = self._subgraph(epoch[-1].node_id) sorter = TopoSorter(nodes, edges) else: sorter = TopoSorter(self.nodes, self.edges) self._frozen_sorters[frozen_key] = sorter return deepcopy(self._frozen_sorters[frozen_key])
[docs] def has_cycle(self) -> bool: """ Checks that the graph is acyclic. """ graph = self._init_graph() cycle = graph.find_cycle() return bool(cycle)
[docs] def generations(self) -> list[tuple[GraphItem, ...]]: """ Get the topological generations of the graph: tuples for each set of nodes that can be run at the same time. Order within a generation is not guaranteed to be stable. """ sorter = self._init_graph() generations = [] while sorter.is_active(): ready = sorter.get_ready() generations.append(ready) sorter.done(*ready) return generations
[docs] def asset_generations(self) -> dict[NodeID, list[tuple[str, ...]]]: """ :meth:`.generations` except only including nodes with direct dependencies on assets, to determine when the asset should be initialized vs. received in the ZMQ Runner. Packed in a dictionary with the asset ID as the key, and the value as the generations for that asset. """ generations = defaultdict(list) asset_ids = set(e.source_signal for e in self.edges if e.source_node == "assets") for gen in self.generations(): for asset in asset_ids: gen_deps = tuple( [ g for g in gen if not isinstance(g, NodeSignal) and any( e.source_node == "assets" and e.source_signal == asset and e.target_node == g for e in self.edges ) ] ) if gen_deps: generations[asset].append(gen_deps) return generations
[docs] def upstream_nodes(self, node: NodeID) -> set[NodeID]: """ All the nodes that have an effect on the given node From: * Dependencies * If the node has optional dependencies, nodes whose NoEvents it should listen to """ upstream = {e.source_node for e in self.edges if e.target_node == node} sorter = self._init_graph() for item, info in sorter.node_info.items(): if node in info.optional_successors: upstream.add(item[0] if isinstance(item, NodeSignal) else item) return upstream
def _subgraph(self, node_id: str) -> tuple[dict[str, NodeSpecification], list[Edge]]: """ Subgraph that is downstream of a given node (including the node itself). """ from noob.tube import downstream_nodes if node_id not in self._subgraphs: downstream = downstream_nodes(self.edges, node_id) self._subgraphs[node_id] = ( {node_id: self.nodes[node_id] for node_id in downstream if node_id in self.nodes}, [e for e in self.edges if e.target_node in downstream], ) return self._subgraphs[node_id] def _done_subepochs( self, epoch: Epoch, node_id: NodeID, signal: SignalName | None = None ) -> None: """ Called when a node in a parent epoch is marked done - mark the node done in all subepochs, but ensure that nodes that are exclusively downstream of this node (i.e. no dependencies on nodes within the mapped subepoch) are removed from the graph. This is to support gather-like operations from non-gather nodes in 3rd party tubes: nodes downstream of both this node and other nodes in the subepoch run in subepochs, but nodes that are exclusively downstream of this node only run in the parent epoch """ from noob.tube import downstream_nodes if not self._subepochs[epoch]: return our_subgraph = set(self._subgraph(node_id)[0]) _exclusive_subgraphs = {} to_mark = NodeSignal(node_id, signal) if signal is not None else node_id for subepoch in self._subepochs[epoch]: if ( to_mark in self._epochs[subepoch].ran_nodes or to_mark not in self._epochs[subepoch].node_info ): # fine continue elif to_mark in self._epochs[subepoch].done_nodes: # needs to be resurrected self._epochs[subepoch].resurrect(to_mark) self._epochs[subepoch].done(to_mark) # mark all nodes that are exclusively downstream of this node expired subep_node = subepoch[-1].node_id if subep_node not in _exclusive_subgraphs: _exclusive_subgraphs[subep_node] = downstream_nodes( self.edges, subep_node, exclude={node_id} ) exclusive_subgraph = our_subgraph - _exclusive_subgraphs[subep_node] - {node_id} for exclusive in 
exclusive_subgraph: self._epochs[subepoch].mark_expired(exclusive)