"""Class corresponding to computation workflow."""
import json
import os
import warnings
from builtins import list
from contextlib import redirect_stdout
from copy import deepcopy
from dataclasses import asdict
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from .._shared_files import logger
from .._shared_files.context_managers import active_lattice_manager
from .._shared_files.defaults import DefaultMetadataValues
from .._shared_files.utils import get_named_params, get_serialized_function_str
from .depsbash import DepsBash
from .depscall import DepsCall
from .depspip import DepsPip
from .transport import TransportableObject, _TransportGraph, encode_metadata
if TYPE_CHECKING:
from .._results_manager.result import Result
from ..executor import BaseExecutor
from ..triggers import BaseTrigger
from .._shared_files.utils import get_imports, get_serialized_function_str
# Module-level list; it is not referenced anywhere else in the visible part of
# this module.
consumable_constraints = []
# Plain-dict form of a default-constructed DefaultMetadataValues dataclass.
# Lattice.build_graph reads it to fill in constraint metadata left unset.
DEFAULT_METADATA_VALUES = asdict(DefaultMetadataValues())
# Aliases for objects exposed by the shared logger module. app_log is used
# below to emit the deprecation warnings; log_stack_info is not referenced in
# the visible part of this module.
app_log = logger.app_log
log_stack_info = logger.log_stack_info
class Lattice:
    """
    A lattice workflow object that holds the work flow graph and is returned by :obj:`lattice <covalent.lattice>` decorator.

    Attributes:
        workflow_function: The workflow function that is decorated by :obj:`lattice <covalent.lattice>` decorator,
            stored wrapped as a transportable object.
        workflow_function_string: Serialized source string of the workflow function.
        transport_graph: The transport graph which will be the basis on how the workflow is executed.
        metadata: Dictionary of metadata of the lattice.
        post_processing: Boolean to indicate if the lattice is in post processing mode or not.
        args: Positional arguments of the last `build_graph` call, each wrapped as a transportable object.
        kwargs: Keyword arguments of the last `build_graph` call, each wrapped as a transportable object.
        named_args: Mapping computed by `get_named_params` for the positional arguments.
        named_kwargs: Mapping computed by `get_named_params` for the keyword arguments.
        electron_outputs: Dictionary of electron outputs received after workflow execution.
        lattice_imports: First value returned by `get_imports` for the workflow function.
        cova_imports: Set, second value returned by `get_imports`, always extended with "electron".
    """

    def __init__(
        self, workflow_function: Callable, transport_graph: Optional["_TransportGraph"] = None
    ) -> None:
        """
        Create a lattice around `workflow_function`.

        Args:
            workflow_function: The callable that defines the workflow.
            transport_graph: Existing transport graph to use; a new empty one is
                created when this is omitted (or falsy).
        """

        self.workflow_function = workflow_function
        self.workflow_function_string = get_serialized_function_str(self.workflow_function)
        self.transport_graph = transport_graph or _TransportGraph()
        self.metadata = {}
        self.__name__ = self.workflow_function.__name__
        self.__doc__ = self.workflow_function.__doc__
        self.post_processing = False
        self.args = []
        self.kwargs = {}
        self.named_args = {}
        self.named_kwargs = {}
        self.electron_outputs = {}
        self.lattice_imports, self.cova_imports = get_imports(self.workflow_function)
        self.cova_imports.update({"electron"})

        # From here on the raw callable is only reachable through
        # `self.workflow_function.get_deserialized()`.
        self.workflow_function = TransportableObject.make_transportable(self.workflow_function)

    def serialize_to_json(self) -> str:
        """
        Serialize every instance attribute of the lattice to a JSON string.

        Transportable objects are converted with `to_dict()`, the metadata with
        `encode_metadata`, the transport graph with its own `serialize_to_json`
        and the `cova_imports` set is turned into a list.

        Returns:
            The JSON document describing this lattice.
        """

        # A shallow copy is sufficient: every non-scalar entry is replaced
        # below by a freshly built container, so `self` is never mutated and
        # the (potentially large) graph/function objects are not deep-copied.
        attributes = dict(self.__dict__)

        attributes["workflow_function"] = self.workflow_function.to_dict()
        attributes["metadata"] = encode_metadata(self.metadata)
        attributes["transport_graph"] = (
            self.transport_graph.serialize_to_json() if self.transport_graph else None
        )

        attributes["args"] = [arg.to_dict() for arg in self.args]
        attributes["kwargs"] = {k: v.to_dict() for k, v in self.kwargs.items()}
        attributes["named_args"] = {k: v.to_dict() for k, v in self.named_args.items()}
        attributes["named_kwargs"] = {k: v.to_dict() for k, v in self.named_kwargs.items()}
        attributes["electron_outputs"] = {
            node_name: output.to_dict() for node_name, output in self.electron_outputs.items()
        }

        # Sets are not JSON serializable.
        attributes["cova_imports"] = list(self.cova_imports)

        return json.dumps(attributes)

    @staticmethod
    def deserialize_from_json(json_data: str) -> "Lattice":
        """
        Rebuild a lattice from the output of `serialize_to_json`.

        Args:
            json_data: JSON string produced by `serialize_to_json`.

        Returns:
            A `Lattice` whose instance dictionary is the decoded attribute set.
        """

        attributes = json.loads(json_data)

        attributes["cova_imports"] = set(attributes["cova_imports"])

        # These four entries are all mappings of name -> serialized transportable object.
        for key in ("electron_outputs", "named_kwargs", "named_args", "kwargs"):
            attributes[key] = {
                name: TransportableObject.from_dict(object_dict)
                for name, object_dict in attributes[key].items()
            }

        attributes["args"] = [TransportableObject.from_dict(arg) for arg in attributes["args"]]

        if attributes["transport_graph"]:
            tg = _TransportGraph()
            tg.deserialize_from_json(attributes["transport_graph"])
            attributes["transport_graph"] = tg

        attributes["workflow_function"] = TransportableObject.from_dict(
            attributes["workflow_function"]
        )

        # The placeholder only exists to get through __init__; all of the
        # instance state is replaced by the decoded attributes right after.
        def dummy_function(x):
            return x

        lat = Lattice(dummy_function)
        lat.__dict__ = attributes
        return lat

    def set_metadata(self, name: str, value: Any) -> None:
        """
        Function to add/edit metadata of given name and value
        to lattice's metadata.

        Args:
            name: Name of the metadata to be added/edited.
            value: Value of the metadata to be added/edited.

        Returns:
            None
        """

        self.metadata[name] = value

    def get_metadata(self, name: str) -> Any:
        """
        Get value of the metadata of given name.

        Args:
            name: Name of the metadata whose value is needed.

        Returns:
            value: Value of the metadata of given name, or None if no metadata
                of that name is present.
        """

        return self.metadata.get(name, None)

    def build_graph(self, *args, **kwargs) -> None:
        """
        Builds the transport graph for the lattice by executing the workflow
        function which will trigger the call of all underlying electrons and
        they will get added to the transport graph for later execution.

        Also redirects any print statements inside the lattice function to null.
        If the workflow function raises, a warning is issued and the exception
        is re-raised to the caller.

        GRAPH WILL NOT BE BUILT AFTER AN EXCEPTION HAS OCCURRED.

        Args:
            *args: Positional arguments to be passed to the workflow function.
            **kwargs: Keyword arguments to be passed to the workflow function.

        Returns:
            None

        Raises:
            Exception: Whatever the workflow function raises while being executed.
        """

        self.args = [TransportableObject.make_transportable(arg) for arg in args]
        self.kwargs = {k: TransportableObject.make_transportable(v) for k, v in kwargs.items()}

        self.transport_graph.reset()

        workflow_function = self.workflow_function.get_deserialized()

        named_args, named_kwargs = get_named_params(workflow_function, self.args, self.kwargs)
        self.named_args = named_args
        self.named_kwargs = named_kwargs

        new_args = [v.get_deserialized() for v in named_args.values()]
        new_kwargs = {k: v.get_deserialized() for k, v in named_kwargs.items()}

        # Fill in defaults for constraints that are unset or empty. `.get` is
        # used so that a lattice created without the decorator (empty metadata)
        # gets the defaults instead of failing with a KeyError.
        constraint_names = {"executor", "workflow_executor", "deps", "call_before", "call_after"}
        new_metadata = {
            name: DEFAULT_METADATA_VALUES[name]
            for name in constraint_names
            if not self.metadata.get(name)
        }
        new_metadata = encode_metadata(new_metadata)

        for k, v in new_metadata.items():
            self.metadata[k] = v

        # The devnull handle is managed by its own `with` so it is closed when
        # the build finishes, whether or not the workflow function raised.
        with open(os.devnull, "w") as devnull:
            with redirect_stdout(devnull):
                with active_lattice_manager.claim(self):
                    try:
                        workflow_function(*new_args, **new_kwargs)
                    except Exception:
                        warnings.warn(
                            "Please make sure you are not manipulating an object inside the lattice."
                        )
                        raise

    def draw(self, *args, **kwargs) -> None:
        """
        Generate lattice graph and display in UI taking into account passed in
        arguments.

        Args:
            *args: Positional arguments to be passed to build the graph.
            **kwargs: Keyword arguments to be passed to build the graph.

        Returns:
            None
        """

        # Imported lazily so the UI package is only needed when drawing.
        import covalent_ui.result_webhook as result_webhook

        self.build_graph(*args, **kwargs)
        result_webhook.send_draw_request(self)

    def __call__(self, *args, **kwargs):
        """Execute lattice as an ordinary function for testing purposes."""

        workflow_function = self.workflow_function.get_deserialized()
        return workflow_function(*args, **kwargs)

    def dispatch(self, *args, **kwargs) -> str:
        """
        DEPRECATED: Function to dispatch workflows.

        Args:
            *args: Positional arguments for the workflow
            **kwargs: Keyword arguments for the workflow

        Returns:
            Dispatch id assigned to job
        """

        # NOTE(review): exc_info receives the DeprecationWarning class rather
        # than an exception instance or True; confirm this is what the logger
        # is expected to be given.
        app_log.warning(
            "workflow.dispatch(your_arguments_here) is deprecated and may get removed without notice in future releases. Please use covalent.dispatch(workflow)(your_arguments_here) instead.",
            exc_info=DeprecationWarning,
        )

        from .._dispatcher_plugins import local_dispatch

        return local_dispatch(self)(*args, **kwargs)

    def dispatch_sync(self, *args, **kwargs) -> "Result":
        """
        DEPRECATED: Function to dispatch workflows synchronously by waiting for the result too.

        Args:
            *args: Positional arguments for the workflow
            **kwargs: Keyword arguments for the workflow

        Returns:
            Result of workflow execution
        """

        # NOTE(review): exc_info receives the DeprecationWarning class rather
        # than an exception instance or True; confirm this is what the logger
        # is expected to be given.
        app_log.warning(
            "workflow.dispatch_sync(your_arguments_here) is deprecated and may get removed without notice in future releases. Please use covalent.dispatch_sync(workflow)(your_arguments_here) instead.",
            exc_info=DeprecationWarning,
        )

        from .._dispatcher_plugins import local_dispatch_sync

        return local_dispatch_sync(self)(*args, **kwargs)
def lattice(
    _func: Optional[Callable] = None,
    *,
    backend: Optional[str] = None,
    executor: Optional[Union[List[Union[str, "BaseExecutor"]], Union[str, "BaseExecutor"]]] = None,
    workflow_executor: Optional[
        Union[List[Union[str, "BaseExecutor"]], Union[str, "BaseExecutor"]]
    ] = None,
    deps_bash: Union[DepsBash, list, str] = None,
    deps_pip: Union[DepsPip, list] = None,
    call_before: Union[List[DepsCall], DepsCall] = [],
    call_after: Union[List[DepsCall], DepsCall] = [],
    triggers: Union["BaseTrigger", List["BaseTrigger"]] = None,
) -> Lattice:
    """
    Decorator that turns a workflow function into a `Lattice <covalent._workflow.lattice.Lattice>` object.

    Usable both bare (``@lattice``) and with keyword options
    (``@lattice(executor=...)``); in the second form a decorator is returned
    that produces the Lattice once applied to the function.

    Args:
        _func: function to be decorated

    Keyword Args:
        backend: DEPRECATED: Same as `executor`. When given it overrides `executor`.
        executor: Executor (name or object) recorded as the lattice's `executor` metadata.
        workflow_executor: Executor (name or object) recorded as the lattice's `workflow_executor` metadata.
        deps_bash: A DepsBash object, or a list/str of shell commands that is wrapped in a DepsBash.
        deps_pip: A DepsPip object, or a list of packages that is wrapped in a DepsPip.
        call_before: A DepsCall or list of DepsCall objects; a single object is wrapped in a list.
        call_after: A DepsCall or list of DepsCall objects; a single object is wrapped in a list.
        triggers: A trigger or list of triggers to attach to this lattice, default is None

    Returns:
        :obj:`Lattice <covalent._workflow.lattice.Lattice>` : Lattice object wrapping the decorated function
        (or the decorator itself when `_func` is not supplied).
    """

    if backend:
        app_log.warning(
            "backend is deprecated and will be removed in a future release. Please use executor keyword instead.",
            exc_info=DeprecationWarning,
        )
        executor = backend

    # Collect the dependency objects, wrapping raw lists/strings as needed.
    # The checks are intentionally independent, mirroring their original order.
    dependencies = {}
    if isinstance(deps_bash, DepsBash):
        dependencies["bash"] = deps_bash
    if isinstance(deps_bash, (list, str)):
        dependencies["bash"] = DepsBash(commands=deps_bash)
    if isinstance(deps_pip, DepsPip):
        dependencies["pip"] = deps_pip
    if isinstance(deps_pip, list):
        dependencies["pip"] = DepsPip(packages=deps_pip)

    # A lone DepsCall / trigger is promoted to a one-element list.
    call_before = [call_before] if isinstance(call_before, DepsCall) else call_before
    call_after = [call_after] if isinstance(call_after, DepsCall) else call_after

    from ..triggers import BaseTrigger

    triggers = [triggers] if isinstance(triggers, BaseTrigger) else triggers

    encoded_constraints = encode_metadata(
        {
            "executor": executor,
            "workflow_executor": workflow_executor,
            "deps": dependencies,
            "call_before": call_before,
            "call_after": call_after,
            "triggers": triggers,
        }
    )

    def decorator_lattice(func=None):
        @wraps(func)
        def wrapper_lattice(*args, **kwargs):
            new_lattice = Lattice(workflow_function=func)
            for key, value in encoded_constraints.items():
                new_lattice.set_metadata(key, value)
            # The transport graph shares the lattice's metadata dictionary.
            new_lattice.transport_graph.lattice_metadata = new_lattice.metadata
            return new_lattice

        # The wrapper is invoked right away: decorating yields the Lattice itself.
        return wrapper_lattice()

    return decorator_lattice if _func is None else decorator_lattice(_func)