diff --git a/src/qutip_qoc/objective.py b/src/qutip_qoc/objective.py index d91caac..986961a 100644 --- a/src/qutip_qoc/objective.py +++ b/src/qutip_qoc/objective.py @@ -80,12 +80,6 @@ def __init__(self, initial, H, target, weight=1): self.target = target self.weight = weight - # Check if any Hamiltonian in H is a superoperator - if any(qt.issuper(H_i) for H_i in (H if isinstance(H, list) else [H])): - # Convert initial and target accordingly - self.initial = qt.to_super(self.initial) - self.target = qt.to_super(self.target) - def __getstate__(self): """ Extract picklable information from the objective. diff --git a/src/qutip_qoc/pulse_optim.py b/src/qutip_qoc/pulse_optim.py index 94d424c..935bf18 100644 --- a/src/qutip_qoc/pulse_optim.py +++ b/src/qutip_qoc/pulse_optim.py @@ -11,6 +11,8 @@ from qutip_qoc._optimizer import _global_local_optimization from qutip_qoc._time import _TimeInterval +import qutip as qt + try: from qutip_qoc._rl import _RL _rl_available = True @@ -28,6 +30,7 @@ def optimize_pulses( optimizer_kwargs=None, minimizer_kwargs=None, integrator_kwargs=None, + optimization_type=None, ): """ Run GOAT, JOPT, GRAPE, CRAB or RL optimization. @@ -124,6 +127,11 @@ def optimize_pulses( Options for the solver, see :obj:`MESolver.options` and `Integrator <./classes.html#classes-ode>`_ for a list of all options. + optimization_type : str, optional + Type of optimization. By default, QuTiP-QOC will try to automatically determine + whether this is a *state transfer* or a *gate synthesis* problem. Set this + flag to ``"state_transfer"`` or ``"gate_synthesis"`` to set the mode manually. + Returns ------- result : :class:`qutip_qoc.Result` @@ -187,10 +195,43 @@ def optimize_pulses( "maxiter": algorithm_kwargs.get("max_iter", 1000), "gtol": algorithm_kwargs.get("min_grad", 0.0 if alg == "CRAB" else 1e-8), } + # Iterate over objectives and convert initial and target states based on the optimization type + for objective in objectives: + H_list = objective.H if isinstance(objective.H, list) else [objective.H] + if any(qt.issuper(H_i) for H_i in H_list): + if isinstance(optimization_type, str) and optimization_type.lower() == "state_transfer": + if qt.isket(objective.initial): + objective.initial = qt.operator_to_vector(qt.ket2dm(objective.initial)) + elif qt.isoper(objective.initial): + objective.initial = qt.operator_to_vector(objective.initial) + if qt.isket(objective.target): + objective.target = qt.operator_to_vector(qt.ket2dm(objective.target)) + elif qt.isoper(objective.target): + objective.target = qt.operator_to_vector(objective.target) + elif isinstance(optimization_type, str) and optimization_type.lower() == "gate_synthesis": + objective.initial = qt.to_super(objective.initial) + objective.target = qt.to_super(objective.target) + elif optimization_type is None: + if qt.isoper(objective.initial) and qt.isoper(objective.target): + if np.isclose((objective.initial).tr(), 1) and np.isclose((objective.target).tr(), 1): + objective.initial = qt.operator_to_vector(objective.initial) + objective.target = qt.operator_to_vector(objective.target) + else: + objective.initial = qt.to_super(objective.initial) + objective.target = qt.to_super(objective.target) + if qt.isket(objective.initial): + objective.initial = qt.operator_to_vector(qt.ket2dm(objective.initial)) + if qt.isket(objective.target): + objective.target = qt.operator_to_vector(qt.ket2dm(objective.target)) # prepare qtrl optimizers qtrl_optimizers = [] if alg == "CRAB" or alg == "GRAPE": + dyn_type = "GEN_MAT" + for objective in objectives: + if any(qt.isoper(H_i) for H_i in (objective.H if isinstance(objective.H, list) else [objective.H])): + dyn_type = "UNIT" + if alg == "GRAPE": # algorithm specific kwargs use_as_amps = True minimizer_kwargs.setdefault("method", "L-BFGS-B") # gradient @@ -247,7 +288,7 @@ def optimize_pulses( "accuracy_factor": None, # deprecated "alg_params": alg_params, "optim_params": algorithm_kwargs.get("optim_params", None), - "dyn_type": algorithm_kwargs.get("dyn_type", "GEN_MAT"), + "dyn_type": algorithm_kwargs.get("dyn_type", dyn_type), "dyn_params": algorithm_kwargs.get("dyn_params", None), "prop_type": algorithm_kwargs.get( "prop_type", "DEF"