from typing import Optional
import pennylane as qml
from pydantic import Field
from braket.aws import AwsQuantumTask, AwsQuantumTaskBatch
from covalent._shared_files.config import get_config
from covalent.experimental.covalent_qelectron.executors.base import (
BaseProcessPoolQExecutor,
BaseThreadPoolQExecutor,
QCResult,
get_thread_pool,
get_process_pool
)
__all__ = ["BraketQubitExecutor", "LocalBraketQubitExecutor"]

# Default settings for each executor defined in this module, keyed by the
# executor's class name.
# NOTE(review): presumably these are merged into the "qelectron" config
# section; the executor classes in this module read ``s3_destination_folder``
# and ``backend`` back via ``get_config("qelectron")`` -- confirm.
_QEXECUTOR_PLUGIN_DEFAULTS = dict(
    BraketQubitExecutor=dict(
        device_arn="",
        s3_destination_folder="",
        poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
        aws_session="",
        parallel=False,
        max_parallel=None,
        max_connections=AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT,
        max_retries=AwsQuantumTaskBatch.MAX_RETRIES,
        run_kwargs={},
        max_jobs=20,
    ),
    LocalBraketQubitExecutor=dict(
        backend="default",
        shots=None,
        run_kwargs={},
        max_jobs=20,
    ),
)
class BraketQubitExecutor(BaseThreadPoolQExecutor):
    """
    The remote Braket executor based on the existing Pennylane Braket
    qubit device. Usage of this device requires valid AWS credentials as
    set up following the instructions at
    https://github.com/aws/amazon-braket-sdk-python#prerequisites.

    Attributes:
        max_jobs:
            maximum number of parallel jobs sent by threads on :code:`batch_submit`.
        shots: number of shots used to estimate quantum observables.
        device_arn:
            an alpha-numeric code (arn=Amazon Resource Name) specifying a quantum device.
        poll_timeout_seconds:
            number of seconds before a poll to remote device is considered timed-out.
        poll_interval_seconds:
            number of seconds between polling of a remote device's status.
        aws_session:
            An :code:`AwsSession` object created to manage interactions with AWS services,
            to be supplied if extra control is desired.
        parallel: turn parallel execution on or off.
        max_parallel: the maximum number of circuits to be executed in parallel.
        max_connections: the maximum number of connections in the :code:`Boto3` connection pool.
        max_retries: the maximum number of times a job will be re-sent if it failed.
        s3_destination_folder: Name of the S3 bucket and folder, specified as a tuple.
        run_kwargs: Variable length keyword arguments for :code:`braket.devices.Device.run()`.
    """

    max_jobs: int = 20
    # No trailing comma: ``shots: int = None,`` made the default the tuple
    # ``(None,)``, which was then forwarded to ``qml.device(shots=...)``.
    shots: Optional[int] = None
    device_arn: Optional[str] = None
    poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT
    poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL
    # NOTE(review): the docstring describes an ``AwsSession`` object, but this
    # annotation only admits ``str``; pydantic validation may reject a real
    # session object -- confirm the intended type.
    aws_session: Optional[str] = None
    parallel: bool = False
    max_parallel: Optional[int] = None
    max_connections: int = AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT
    max_retries: int = AwsQuantumTaskBatch.MAX_RETRIES
    # Looked up from the "qelectron" config when no value is supplied.
    s3_destination_folder: tuple = Field(
        default_factory=lambda: get_config(
            "qelectron")["BraketQubitExecutor"]["s3_destination_folder"]
    )
    # pydantic copies mutable field defaults per instance, so this is not shared.
    run_kwargs: dict = {}

    def batch_submit(self, qscripts_list):
        """
        Submit qscripts for execution using :code:`max_jobs`-many threads.

        Args:
            qscripts_list: a list of Pennylane style :code:`QuantumScripts`.

        Returns:
            jobs: a :code:`list` of the objects returned by the thread pool's
            :code:`submit`, one per qscript, in input order.
        """
        # A ``shots`` value of 0 falls back to ``self.qnode_device_shots``
        # (presumably the shots of the QNode's original device).
        device_shots = self.shots if self.shots != 0 else self.qnode_device_shots
        p = get_thread_pool(self.max_jobs)
        jobs = []
        for qscript in qscripts_list:
            # A fresh Braket device per script, sized to that script's wires.
            dev = qml.device(
                "braket.aws.qubit",
                wires=qscript.wires,
                device_arn=self.device_arn,
                s3_destination_folder=self.s3_destination_folder,
                shots=device_shots,
                poll_timeout_seconds=self.poll_timeout_seconds,
                poll_interval_seconds=self.poll_interval_seconds,
                aws_session=self.aws_session,
                parallel=self.parallel,
                max_parallel=self.max_parallel,
                max_connections=self.max_connections,
                max_retries=self.max_retries,
                **self.run_kwargs
            )
            result_obj = QCResult.with_metadata(
                device_name=dev.short_name,
                executor=self,
            )
            jobs.append(p.submit(self.run_circuit, qscript, dev, result_obj))
        return jobs

    def dict(self, *args, **kwargs):
        """
        Return the model as a dict, with :code:`run_kwargs` flattened into a
        tuple of ``(key, value)`` pairs.

        The flattening presumably exists so the serialized executor is
        hashable; values nested inside :code:`run_kwargs` are left unchanged.
        """
        dict_ = super().dict(*args, **kwargs)
        # ``run_kwargs`` is absent when the caller passes ``exclude``/``include``.
        if "run_kwargs" in dict_:
            dict_["run_kwargs"] = tuple(dict_["run_kwargs"].items())
        return dict_
class LocalBraketQubitExecutor(BaseProcessPoolQExecutor):
    """
    The local Braket executor based on the existing Pennylane local Braket qubit device.

    Attributes:
        max_jobs: maximum number of parallel jobs sent by processes on :code:`batch_submit`.
            NOTE(review): :code:`batch_submit` sizes its process pool from
            :code:`self.num_processes`, not from this field -- confirm which is intended.
        shots: number of shots used to estimate quantum observables.
        backend:
            The name of the simulator backend. When not given, it is read from the
            "qelectron" config; the plugin default is the :code:`"default"` backend.
        run_kwargs: Variable length keyword arguments for :code:`braket.devices.Device.run()`.
    """

    max_jobs: int = 20
    shots: Optional[int] = None
    # Looked up from the "qelectron" config when no value is supplied.
    backend: str = Field(
        default_factory=lambda: get_config("qelectron")["LocalBraketQubitExecutor"]["backend"]
    )
    # pydantic copies mutable field defaults per instance, so this is not shared.
    run_kwargs: dict = {}

    def batch_submit(self, qscripts_list):
        """
        Submit qscripts for execution using :code:`num_processes`-many processes.

        Args:
            qscripts_list: a list of Pennylane style :code:`QuantumScripts`.

        Returns:
            futures: a :code:`list` of the results of the process pool's
            :code:`apply_async`, one per qscript, in input order.
        """
        # A ``shots`` value of 0 falls back to ``self.qnode_device_shots``
        # (presumably the shots of the QNode's original device).
        device_shots = self.shots if self.shots != 0 else self.qnode_device_shots
        # NOTE(review): ``num_processes`` is not declared on this class (presumably
        # inherited from BaseProcessPoolQExecutor); ``max_jobs`` is not consulted here.
        pool = get_process_pool(self.num_processes)
        futures = []
        for qscript in qscripts_list:
            # A fresh local simulator device per script, sized to that script's wires.
            dev = qml.device(
                "braket.local.qubit",
                wires=qscript.wires,
                backend=self.backend,
                shots=device_shots,
                **self.run_kwargs
            )
            result_obj = QCResult.with_metadata(
                device_name=dev.short_name,
                executor=self,
            )
            fut = pool.apply_async(self.run_circuit, args=(qscript, dev, result_obj))
            futures.append(fut)
        return futures

    def dict(self, *args, **kwargs):
        """
        Return the model as a dict, with :code:`run_kwargs` flattened into a
        tuple of ``(key, value)`` pairs.

        The flattening presumably exists so the serialized executor is
        hashable; values nested inside :code:`run_kwargs` are left unchanged.
        """
        dict_ = super().dict(*args, **kwargs)
        # ``run_kwargs`` is absent when the caller passes ``exclude``/``include``.
        if "run_kwargs" in dict_:
            dict_["run_kwargs"] = tuple(dict_["run_kwargs"].items())
        return dict_