"""Class implementation of the transport graph in the workflow graph."""
import base64
import json
import platform
from copy import deepcopy
from typing import Any, Callable, Dict
import cloudpickle
import networkx as nx
from .._shared_files.defaults import parameter_prefix
from .._shared_files.util_classes import RESULT_STATUS
class TransportableObject:
"""
    A function is converted to a transportable object by serializing it with
    cloudpickle; the transportable object is then deserialized whenever the
    function is executed. The object also carries additional information, such
    as the Python version used to serialize it.
    Attributes:
        _object: The base64-encoded, cloudpickled object.
        python_version: The Python version used on the client's machine.
        object_string: The string representation of the object.
        attrs: The object's name and docstring, when available.
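
    Example (illustrative sketch of the serialize/deserialize round trip)::

        to = TransportableObject(lambda x: x + 1)
        fn = to.get_deserialized()
        assert fn(1) == 2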
"""
def __init__(self, obj: Any) -> None:
self._object = base64.b64encode(cloudpickle.dumps(obj)).decode("utf-8")
self.python_version = platform.python_version()
self.object_string = str(obj)
try:
self._json = json.dumps(obj)
        except TypeError:
            self._json = ""
self.attrs = {"doc": getattr(obj, "__doc__", ""), "name": getattr(obj, "__name__", "")}
def __eq__(self, obj) -> bool:
return self.__dict__ == obj.__dict__ if isinstance(obj, TransportableObject) else False
def get_deserialized(self) -> Callable:
"""
        Get the deserialized underlying object.
        Args:
            None
        Returns:
            function: The deserialized object, e.g. a callable function.
"""
return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8")))
@property
    def json(self):
        """The JSON representation of the object, or an empty string if the object is not JSON-serializable."""
        return self._json
def to_dict(self) -> dict:
"""Return a JSON-serializable dictionary representation of self"""
return {"type": "TransportableObject", "attributes": self.__dict__.copy()}
@staticmethod
def from_dict(object_dict) -> "TransportableObject":
"""Rehydrate a dictionary representation
Args:
object_dict: a dictionary representation returned by `to_dict`
Returns:
A `TransportableObject` represented by `object_dict`
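
        Example (illustrative sketch)::

            to = TransportableObject(42)
            rehydrated = TransportableObject.from_dict(to.to_dict())
            assert rehydrated.get_deserialized() == 42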
"""
sc = TransportableObject(None)
sc.__dict__ = object_dict["attributes"]
return sc
def get_serialized(self) -> str:
"""
Get the serialized transportable object.
Args:
None
Returns:
            object: The serialized object (a base64-encoded string).
"""
return self._object
def serialize(self) -> bytes:
"""
Serialize the transportable object.
Args:
None
Returns:
            pickled_object: The serialized object along with the Python version.
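
        Example (illustrative sketch; pairs with `TransportableObject.deserialize`)::

            data = TransportableObject("hello").serialize()
            restored = TransportableObject.deserialize(data)
            assert restored.get_deserialized() == "hello"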
"""
return cloudpickle.dumps(
{
"object": self.get_serialized(),
"object_string": self.object_string,
"json": self._json,
"attrs": self.attrs,
"py_version": self.python_version,
}
)
def serialize_to_json(self) -> str:
"""
Serialize the transportable object to JSON.
Args:
None
Returns:
A JSON string representation of the transportable object
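
        Example (illustrative sketch; pairs with `deserialize_from_json`)::

            s = TransportableObject([1, 2]).serialize_to_json()
            restored = TransportableObject.deserialize_from_json(s)
            assert restored.get_deserialized() == [1, 2]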
"""
return json.dumps(self.to_dict())
@staticmethod
    def deserialize_from_json(json_string: str) -> "TransportableObject":
"""
Reconstruct a transportable object from JSON
Args:
json_string: A JSON string representation of a TransportableObject
Returns:
A TransportableObject instance
"""
object_dict = json.loads(json_string)
return TransportableObject.from_dict(object_dict)
@staticmethod
    def make_transportable(obj) -> "TransportableObject":
        """Wrap an object in a TransportableObject unless it already is one."""
if isinstance(obj, TransportableObject):
return obj
else:
return TransportableObject(obj)
@staticmethod
def deserialize(data: bytes) -> "TransportableObject":
"""
Deserialize the transportable object.
Args:
            data: The cloudpickled representation produced by `serialize`.
Returns:
object: The deserialized transportable object.
"""
        obj = cloudpickle.loads(data)
        sc = TransportableObject(None)
        sc._object = obj["object"]
        sc.object_string = obj["object_string"]
        sc._json = obj["json"]
        sc.attrs = obj["attrs"]
        sc.python_version = obj["py_version"]
        return sc
@staticmethod
def deserialize_list(collection: list) -> list:
"""
        Recursively deserialize a list of TransportableObjects. More
        precisely, `collection` is a list, each of whose entries is
        assumed to be either a `TransportableObject`, a list, or a dict.
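
        Example (illustrative sketch; nested dicts recurse through
        `deserialize_dict`)::

            col = [TransportableObject(1), {"x": TransportableObject(2)}]
            assert TransportableObject.deserialize_list(col) == [1, {"x": 2}]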
"""
new_list = []
for item in collection:
if isinstance(item, TransportableObject):
new_list.append(item.get_deserialized())
elif isinstance(item, list):
new_list.append(TransportableObject.deserialize_list(item))
elif isinstance(item, dict):
new_list.append(TransportableObject.deserialize_dict(item))
else:
raise TypeError("Couldn't deserialize collection")
return new_list
@staticmethod
def deserialize_dict(collection: dict) -> dict:
"""
        Recursively deserialize a dict of TransportableObjects. More
        precisely, `collection` is a dict, each of whose values is
        assumed to be either a `TransportableObject`, a list, or a dict.
"""
new_dict = {}
for k, item in collection.items():
if isinstance(item, TransportableObject):
new_dict[k] = item.get_deserialized()
elif isinstance(item, list):
new_dict[k] = TransportableObject.deserialize_list(item)
elif isinstance(item, dict):
new_dict[k] = TransportableObject.deserialize_dict(item)
else:
raise TypeError("Couldn't deserialize collection")
return new_dict
def encode_metadata(metadata: dict) -> dict:
    """Encode a metadata dictionary into a JSON-serializable form.
    Executor and workflow-executor objects are flattened into their short
    names together with companion `executor_data` and `workflow_executor_data`
    dictionaries, while deps, call_before, call_after, and trigger objects are
    replaced by their dictionary representations.
    Args:
        metadata: The metadata dictionary to encode.
    Returns:
        An encoded copy of the metadata; the input dict is not modified.
    """
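    # Illustrative sketch: a plain string executor passes through unchanged
    # and merely gains an empty companion dict:
    #     encode_metadata({"executor": "local"})
    #     # -> {"executor": "local", "executor_data": {}}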
encoded_metadata = deepcopy(metadata)
if "executor" in metadata:
if "executor_data" not in metadata:
encoded_metadata["executor_data"] = {}
if metadata["executor"] is not None and not isinstance(metadata["executor"], str):
encoded_executor = metadata["executor"].to_dict()
encoded_metadata["executor"] = encoded_executor["short_name"]
encoded_metadata["executor_data"] = encoded_executor
if "workflow_executor" in metadata:
if "workflow_executor_data" not in metadata:
encoded_metadata["workflow_executor_data"] = {}
if metadata["workflow_executor"] is not None and not isinstance(
metadata["workflow_executor"], str
):
encoded_wf_executor = metadata["workflow_executor"].to_dict()
encoded_metadata["workflow_executor"] = encoded_wf_executor["short_name"]
encoded_metadata["workflow_executor_data"] = encoded_wf_executor
if "deps" in metadata and metadata["deps"] is not None:
for dep_type, dep_object in metadata["deps"].items():
if dep_object and not isinstance(dep_object, dict):
encoded_metadata["deps"][dep_type] = dep_object.to_dict()
if "call_before" in metadata and metadata["call_before"] is not None:
for i, dep in enumerate(metadata["call_before"]):
if not isinstance(dep, dict):
encoded_metadata["call_before"][i] = dep.to_dict()
if "call_after" in metadata and metadata["call_after"] is not None:
for i, dep in enumerate(metadata["call_after"]):
if not isinstance(dep, dict):
encoded_metadata["call_after"][i] = dep.to_dict()
if "triggers" in metadata:
if isinstance(metadata["triggers"], list):
encoded_metadata["triggers"] = []
for tr in metadata["triggers"]:
if isinstance(tr, dict):
encoded_metadata["triggers"].append(tr)
else:
encoded_metadata["triggers"].append(tr.to_dict())
else:
encoded_metadata["triggers"] = metadata["triggers"]
return encoded_metadata
class _TransportGraph:
"""
    The transport graph is the core data structure of a workflow. It contains
    all the information about each electron and lattice required to determine
    how, when, and where to execute the workflow, in the form of a directed
    multigraph whose topology determines the execution order of the nodes.
    Each node in this graph is an electron that is ready to be executed.
    Attributes:
        _graph: The directed graph object of type networkx.MultiDiGraph.
        lattice_metadata: The lattice metadata of the transport graph.
        dirty_nodes: The list of node ids whose attribute values have been
            modified via `set_node_value`.
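
    Example (illustrative sketch)::

        tg = _TransportGraph()
        parent = tg.add_node("task_a", lambda: 1, metadata={})
        child = tg.add_node("task_b", lambda x: x + 1, metadata={})
        tg.add_edge(parent, child, edge_name="x")
        assert tg.get_dependencies(child) == [parent]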
"""
def __init__(self) -> None:
self._graph = nx.MultiDiGraph()
self.lattice_metadata = None
self.dirty_nodes = []
self._default_node_attrs = {
"start_time": None,
"end_time": None,
"status": RESULT_STATUS.NEW_OBJECT,
"output": None,
"error": None,
"sub_dispatch_id": None,
"sublattice_result": None,
"stdout": None,
"stderr": None,
}
def add_node(self, name: str, function: Callable, metadata: Dict, **attr) -> int:
"""
Adds a node to the graph.
Args:
name: The name of the node.
function: The function to be executed.
metadata: The metadata of the node.
attr: Any other attributes that need to be added to the node.
Returns:
            node_id: The node id.
"""
node_id = len(self._graph.nodes)
self._graph.add_node(
node_id,
name=name,
function=TransportableObject(function),
metadata=metadata,
**attr,
)
return node_id
def add_edge(self, x: int, y: int, edge_name: Any, **attr) -> None:
"""
        Adds an edge between two nodes of the graph and assigns a name to it.
        Since networkx does not preserve edge insertion order, the edge name
        is used to record the order of positional arguments passed into the
        electron, so that the order can be restored when the request is
        deserialized in the lattice.
        Args:
            x: The node id of the parent node.
            y: The node id of the child node.
            edge_name: The name to be assigned to the edge.
            attr: Any other attributes that need to be added to the edge.
        Returns:
            None
        Note:
            Since the graph is a multigraph, adding an edge between the same
            pair of nodes creates a parallel edge rather than raising an error.
"""
self._graph.add_edge(x, y, edge_name=edge_name, **attr)
def reset(self) -> None:
"""
Resets the graph.
Args:
None
Returns:
None
"""
self._graph = nx.MultiDiGraph()
def get_node_value(self, node_key: int, value_key: str) -> Any:
"""
Get a specific value from a node depending upon the value key.
Args:
node_key: The node id.
value_key: The value key.
Returns:
value: The value from the node stored at the value key.
Raises:
KeyError: If the value key or node key is not found.
"""
return self._graph.nodes[node_key][value_key]
    def set_node_value(self, node_key: int, value_key: str, value: Any) -> None:
"""
Set a certain value of a node. This allows for saving custom data
in the graph nodes.
Args:
node_key: The node id.
value_key: The value key.
value: The value to be set at value_key position of the node.
Returns:
None
Raises:
KeyError: If the node key is not found.
"""
self.dirty_nodes.append(node_key)
self._graph.nodes[node_key][value_key] = value
def get_edge_data(self, dep_key: int, node_key: int) -> Any:
"""
        Get the metadata for all edges between two nodes.
        Args:
            dep_key: The node id of the parent node.
            node_key: The node id of the child node.
        Returns:
            values: A dict mapping each edge key to that edge's attribute
                dict, or None if no edge exists between the nodes.
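
        Example (illustrative sketch on a graph `tg`; parallel edges each get
        their own integer key)::

            tg.add_edge(0, 1, edge_name="arg_1")
            tg.add_edge(0, 1, edge_name="arg_2")
            tg.get_edge_data(0, 1)
            # -> {0: {"edge_name": "arg_1"}, 1: {"edge_name": "arg_2"}}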
"""
return self._graph.get_edge_data(dep_key, node_key)
def get_dependencies(self, node_key: int) -> list:
"""
Gets the parent node ids of a node.
Args:
node_key: The node id.
Returns:
parents: The dependencies of the node.
"""
return list(self._graph.predecessors(node_key))
def get_internal_graph_copy(self) -> nx.MultiDiGraph:
"""
Get a copy of the internal directed graph
to avoid modifying the original graph.
Args:
None
Returns:
graph: A copy of the internal directed graph.
"""
return self._graph.copy()
def reset_node(self, node_id: int) -> None:
"""Reset node values to starting state."""
for node_attr, default_val in self._default_node_attrs.items():
self.set_node_value(node_id, node_attr, default_val)
def _replace_node(self, node_id: int, new_attrs: Dict[str, Any]) -> None:
"""Replace node data with new attribute values and flag descendants (used in re-dispatching)."""
metadata = self.get_node_value(node_id, "metadata")
metadata.update(new_attrs["metadata"])
serialized_callable = TransportableObject.from_dict(new_attrs["function"])
self.set_node_value(node_id, "function", serialized_callable)
self.set_node_value(node_id, "function_string", new_attrs["function_string"])
self.set_node_value(node_id, "name", new_attrs["name"])
self._reset_descendants(node_id)
def _reset_descendants(self, node_id: int) -> None:
"""Reset node and all its descendants to starting state."""
        try:
            # A node that is still in its initial state has not run, so
            # neither it nor its descendants need to be reset.
            if self.get_node_value(node_id, "status") == RESULT_STATUS.NEW_OBJECT:
                return
        except KeyError:
            # The node does not exist in the graph; nothing to reset.
            return
self.reset_node(node_id)
for successor in self._graph.neighbors(node_id):
self._reset_descendants(successor)
    def apply_electron_updates(self, electron_updates: Dict[str, Dict[str, Any]]) -> None:
"""Replace transport graph node data based on the electrons that need to be updated during re-dispatching."""
for n in self._graph.nodes:
name = self.get_node_value(n, "name")
if name in electron_updates:
self._replace_node(n, electron_updates[name])
def serialize(self, metadata_only: bool = False) -> bytes:
"""
        Convert the transport graph into a pickled byte representation for the
        workflow scheduler. The networkx.MultiDiGraph object is converted into
        node-link format, computation-specific attributes are optionally
        filtered out, and the lattice metadata is added. Each node's function
        Callable is stored in its serialized TransportableObject form, i.e. as
        the base64-encoded cloudpickled result.
        Args:
            metadata_only: If True, serialize only the metadata.
        Returns:
            bytes: The cloudpickled representation of the transport graph.
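
        Example (illustrative sketch; assumes `tg` is a populated
        _TransportGraph and pairs with `deserialize`)::

            buf = tg.serialize()
            tg2 = _TransportGraph()
            tg2.deserialize(buf)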
"""
data = nx.readwrite.node_link_data(self._graph)
for idx, node in enumerate(data["nodes"]):
data["nodes"][idx]["function"] = data["nodes"][idx].pop("function").serialize()
if metadata_only:
            # Filter out parameter nodes when only metadata is requested.
            parameter_node_ids = [
                node["id"]
                for node in data["nodes"]
                if node["name"].startswith(parameter_prefix)
            ]
            for node in data["nodes"].copy():
                if node["id"] in parameter_node_ids:
                    data["nodes"].remove(node)
for idx, node in enumerate(data["nodes"]):
for field in data["nodes"][idx].copy():
if field != "metadata":
data["nodes"][idx].pop(field, None)
for idx, node in enumerate(data["links"]):
for name in data["links"][idx].copy():
if name not in ["source", "target"]:
data["links"][idx].pop("edge_name", None)
data["lattice_metadata"] = self.lattice_metadata
return cloudpickle.dumps(data)
def serialize_to_json(self, metadata_only: bool = False) -> str:
"""
        Convert the transport graph to JSON to be used in the workflow
        scheduler. The networkx.MultiDiGraph object is converted into
        node-link format, computation-specific attributes are optionally
        filtered out, and the lattice metadata is added. Metadata objects are
        converted into their dictionary representations, and each node's
        function Callable is stored as a dictionary containing its
        base64-encoded cloudpickled form.
        Args:
            metadata_only: If True, serialize only the metadata.
        Returns:
            str: The JSON string representation of the transport graph.
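
        Example (illustrative sketch; assumes `tg.lattice_metadata` has been
        populated, since it is re-encoded here)::

            s = tg.serialize_to_json()
            tg2 = _TransportGraph()
            tg2.deserialize_from_json(s)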
"""
data = nx.readwrite.node_link_data(self._graph)
for idx, node in enumerate(data["nodes"]):
data["nodes"][idx]["function"] = data["nodes"][idx].pop("function").to_dict()
if "value" in node:
node["value"] = node["value"].to_dict()
if "metadata" in node:
node["metadata"] = encode_metadata(node["metadata"])
if metadata_only:
            # Filter out parameter nodes when only metadata is requested.
            parameter_node_ids = [
                node["id"]
                for node in data["nodes"]
                if node["name"].startswith(parameter_prefix)
            ]
            for node in data["nodes"].copy():
                if node["id"] in parameter_node_ids:
                    data["nodes"].remove(node)
for idx, node in enumerate(data["nodes"]):
for field in data["nodes"][idx].copy():
if field != "metadata":
data["nodes"][idx].pop(field, None)
for idx, node in enumerate(data["links"]):
for name in data["links"][idx].copy():
if name not in ["source", "target"]:
data["links"][idx].pop("edge_name", None)
data["lattice_metadata"] = encode_metadata(self.lattice_metadata)
return json.dumps(data)
def deserialize(self, pickled_data: bytes) -> None:
"""
Load pickled representation of transport graph into the transport graph instance.
This overwrites anything currently set in the transport graph and deserializes
the base64 encoded cloudpickled function into a Callable.
Args:
pickled_data: Cloudpickled representation of the transport graph
Returns:
None
"""
node_link_data = cloudpickle.loads(pickled_data)
if "lattice_metadata" in node_link_data:
self.lattice_metadata = node_link_data["lattice_metadata"]
for idx, _ in enumerate(node_link_data["nodes"]):
function_ser = node_link_data["nodes"][idx].pop("function")
node_link_data["nodes"][idx]["function"] = TransportableObject.deserialize(
function_ser
)
self._graph = nx.readwrite.node_link_graph(node_link_data)
def deserialize_from_json(self, json_data: str) -> None:
"""Load JSON representation of transport graph into the transport graph instance.
        This overwrites anything currently set in the transport graph. Note
        that node-level and lattice-level metadata must be reconstituted from
        their dictionary representations when needed.
Args:
json_data: JSON representation of the transport graph
Returns:
None
"""
node_link_data = json.loads(json_data)
if "lattice_metadata" in node_link_data:
self.lattice_metadata = node_link_data["lattice_metadata"]
for idx, node in enumerate(node_link_data["nodes"]):
function_ser = node_link_data["nodes"][idx].pop("function")
node_link_data["nodes"][idx]["function"] = TransportableObject.from_dict(function_ser)
if "value" in node:
node["value"] = TransportableObject.from_dict(node["value"])
self._graph = nx.readwrite.node_link_graph(node_link_data)