import base64
from typing import Callable, Union
from ..shared_utils import cloudpickle_deserialize, cloudpickle_serialize
from .base import AsyncBaseQCluster
from .default_selectors import selector_map
# Names exported by `from <module> import *`.
__all__ = ["QCluster"]
class QCluster(AsyncBaseQCluster):
    """
    A cluster of quantum executors.

    Args:
        executors: A sequence of quantum executors.
        selector: A callable that selects an executor, or one of the strings "cyclic"
            or "random". The "cyclic" selector (default) cycles through `executors`
            and returns the next executor for each submitted batch. The "random"
            selector chooses an executor from `executors` at random for each
            submitted batch. Any user-defined selector must be callable with two
            positional arguments: the list of quantum scripts being submitted and
            the sequence of executors. A selector must also return exactly one
            executor.
    """

    # A selector name (key of `selector_map`), a callable selector, or -- after
    # `serialize_selector` -- a base64-encoded cloudpickle string.
    selector: Union[str, Callable] = "cyclic"

    # True while `selector` holds the base64-encoded serialized form.
    _selector_serialized: bool = False

    def batch_submit(self, qscripts_list):
        """
        Choose an executor with the selector and submit the batch to it.

        The cluster's QNode device settings are copied onto the chosen executor
        before delegating to it.

        Args:
            qscripts_list: The quantum scripts to execute; passed unchanged to the
                selector and to the chosen executor's `batch_submit`.

        Returns:
            Whatever the selected executor's `batch_submit` returns.
        """
        # `get_selector` takes care of deserializing the selector if needed.
        selector = self.get_selector()
        selected_executor = selector(qscripts_list, self.executors)

        # Propagate the cluster's device configuration to the chosen executor.
        selected_executor.qnode_device_import_path = self.qnode_device_import_path
        selected_executor.qnode_device_shots = self.qnode_device_shots
        selected_executor.qnode_device_wires = self.qnode_device_wires
        selected_executor.pennylane_active_return = self.pennylane_active_return

        return selected_executor.batch_submit(qscripts_list)

    def serialize_selector(self) -> None:
        """
        Replace `self.selector` with a base64-encoded cloudpickle string.

        Does nothing if the selector is already serialized.
        """
        if self._selector_serialized:
            return
        # Keep the intermediate pickled bytes in a local rather than storing
        # them on the `str`/callable-typed field.
        pickled = cloudpickle_serialize(self.selector)
        self.selector = base64.b64encode(pickled).decode("utf-8")
        self._selector_serialized = True

    def deserialize_selector(self) -> Union[str, Callable]:
        """
        Restore the selector from its serialized form and return it.

        `self.selector` and the serialized flag are updated together, so the
        instance stays consistent even if the caller ignores the return value.
        If the selector is not serialized, `self.selector` is returned unchanged.

        Returns:
            The (deserialized) selector: a selector name or a callable.
        """
        if not self._selector_serialized:
            return self.selector
        selector = cloudpickle_deserialize(
            base64.b64decode(self.selector.encode("utf-8"))
        )
        # Write back before clearing the flag; previously only the flag was
        # reset, leaving the base64 string in `self.selector`.
        self.selector = selector
        self._selector_serialized = False
        return selector

    def dict(self, *args, **kwargs) -> dict:
        """
        Return the model's dict with each executor replaced by its JSON string.

        NOTE(review): `super(AsyncBaseQCluster, self)` starts the method lookup
        *after* `AsyncBaseQCluster` in the MRO, so any `dict` override defined on
        `AsyncBaseQCluster` is skipped -- presumably deliberate; confirm.
        """
        dict_ = super(AsyncBaseQCluster, self).dict(*args, **kwargs)
        dict_.update(executors=tuple(ex.json() for ex in self.executors))
        return dict_

    def get_selector(self) -> Callable:
        """
        Wraps `self.selector` to return defaults corresponding to string values.

        A serialized selector is deserialized first. A string selector is looked
        up in `selector_map`, instantiated, and stored back on `self.selector`,
        so the same instance is reused on later calls. This method is called
        inside `batch_submit`.

        Returns:
            The callable selector.

        Raises:
            KeyError: If `self.selector` is a string that is not a key of
                `selector_map`.
        """
        self.selector = self.deserialize_selector()
        if isinstance(self.selector, str):
            try:
                selector_cls = selector_map[self.selector]
            except KeyError as err:
                # Same exception type as before, but with an actionable message.
                raise KeyError(
                    f"Unknown selector {self.selector!r}; expected a callable "
                    f"or one of {list(selector_map)}."
                ) from err
            self.selector = selector_cls()
        return self.selector