diff --git a/src/torchsim/models/mpnrage.py b/src/torchsim/models/mpnrage.py index 9e7ac3b..9088767 100644 --- a/src/torchsim/models/mpnrage.py +++ b/src/torchsim/models/mpnrage.py @@ -79,6 +79,8 @@ def set_sequence( nshots: int, flip: float, TR: float, + MPRAGE_TR: float | None = None, + num_inversions: int = 1, TI: float = 0.0, slice_prof: float | npt.ArrayLike = 1.0, ): @@ -96,6 +98,10 @@ def set_sequence( TI : float, optional Inversion time in milliseconds. The default is ``0.0``. + MPRAGE_TR : float default is None + Repetition time in milliseconds for the whole inversion block. + num_inversions : int, optional + Number of inversion pulses, default is ``1``. slice_prof : float | npt.ArrayLike, optional Flip angle scaling along slice profile. The default is ``1.0``. @@ -106,6 +112,8 @@ def set_sequence( self.sequence.TR = TR * 1e-3 # ms -> s self.sequence.TI = TI * 1e-3 # ms -> s self.sequence.slice_prof = slice_prof + self.sequence.num_inversions = num_inversions + self.sequence.MPRAGE_TR = MPRAGE_TR * 1e-3 if MPRAGE_TR is not None else None @staticmethod def _engine( @@ -118,6 +126,8 @@ def _engine( B1: float | npt.ArrayLike = 1.0, inv_efficiency: float | npt.ArrayLike = 1.0, slice_prof: float | npt.ArrayLike = 1.0, + num_inversions: int = 1, + MPRAGE_TR: float = None, ): # Prepare relaxation parameters R1 = 1e3 / T1 @@ -136,26 +146,29 @@ def _engine( # Prepare relaxation operator for sequence loop E1, rE1 = epg.longitudinal_relaxation_op(R1, TR) - + if MPRAGE_TR is not None: + mprageE1, mpragerE1 = epg.longitudinal_relaxation_op(R1, MPRAGE_TR) # Initialize signal signal = [] + for i in range(num_inversions): + # Apply inversion + states = epg.adiabatic_inversion(states, inv_efficiency) + states = epg.longitudinal_relaxation(states, E1inv, rE1inv) + states = epg.spoil(states) - # Apply inversion - states = epg.adiabatic_inversion(states, inv_efficiency) - states = epg.longitudinal_relaxation(states, E1inv, rE1inv) - states = epg.spoil(states) - - # Scan loop - for p in range(nshots): + # Scan loop + for p in range(nshots): - # Apply RF pulse - states = epg.rf_pulse(states, RF) + # Apply RF pulse + states = epg.rf_pulse(states, RF) - # Record signal - signal.append(epg.get_signal(states)) + # Record signal + signal.append(epg.get_signal(states)) - # Evolve - states = epg.longitudinal_relaxation(states, E1, rE1) - states = epg.spoil(states) + # Evolve + states = epg.longitudinal_relaxation(states, E1, rE1) + states = epg.spoil(states) + if MPRAGE_TR is not None: + epg.longitudinal_relaxation(states, mprageE1, mpragerE1) return M0 * 1j * torch.stack(signal) diff --git a/src/torchsim/models/mprage.py b/src/torchsim/models/mprage.py index b63b380..cf87338 100644 --- a/src/torchsim/models/mprage.py +++ b/src/torchsim/models/mprage.py @@ -76,6 +76,8 @@ def set_sequence( flip: float, TRspgr: float, nshots: int | npt.ArrayLike, + TRmprage: float = None, + num_inversions: int = 1, ): """ Set sequence parameters for the SPGR model. @@ -88,13 +90,15 @@ def set_sequence( Flip angle train in degrees. TRspgr : float Repetition time in milliseconds for each SPGR readout. - TRmprage : float + TRmprage : float default is None Repetition time in milliseconds for the whole inversion block. nshots : int | npt.ArrayLike Number of SPGR readout within the inversion block of shape ``(npre, npost)`` If scalar, assume ``npre == npost == 0.5 * nshots``. Usually, this is the number of slice encoding lines ``(nshots = nz / Rz)``, i.e., the number of slices divided by the total acceleration factor along ``z``. + num_inversions : int, optional + Number of inversion pulses, default is ``1``. """ self.sequence.nshots = nshots @@ -104,6 +108,10 @@ def set_sequence( if nshots.numel() == 1: nshots = torch.repeat_interleave(nshots // 2, 2) self.sequence.nshots = nshots + self.sequence.TRmprage = TRmprage * 1e-3 # ms -> s + if TRmprage is None and num_inversions > 1: + raise ValueError("TRmprage must be provided for multiple inversions") + self.sequence.num_inversions = num_inversions @staticmethod def _engine( @@ -115,6 +123,7 @@ def _engine( nshots: int | npt.ArrayLike, M0: float | npt.ArrayLike = 1.0, inv_efficiency: float | npt.ArrayLike = 1.0, + num_inversions: int = 1, ): R1 = 1e3 / T1 @@ -135,21 +144,26 @@ def _engine( # Prepare relaxation operator for sequence loop E1, rE1 = epg.longitudinal_relaxation_op(R1, TRspgr) + if TRmprage is not None: + mprageE1, mpragerE1 = epg.longitudinal_relaxation_op(R1, TRmprage) - # Apply inversion - states = epg.adiabatic_inversion(states, inv_efficiency) - states = epg.longitudinal_relaxation(states, E1inv, rE1inv) - states = epg.spoil(states) - + signal = [] # Scan loop - for p in range(nshots_bef): - - # Apply RF pulse - states = epg.rf_pulse(states, RF) - - # Evolve - states = epg.longitudinal_relaxation(states, E1, rE1) + for i in range(num_inversions): + # Apply inversion + states = epg.adiabatic_inversion(states, inv_efficiency) + states = epg.longitudinal_relaxation(states, E1inv, rE1inv) states = epg.spoil(states) - + + for p in range(nshots_bef*2): + + # Apply RF pulse + states = epg.rf_pulse(states, RF) + # Evolve + states = epg.longitudinal_relaxation(states, E1, rE1) + signal.append(M0 * 1j * epg.get_signal(states)) + states = epg.spoil(states) + if TRmprage is not None: + epg.longitudinal_relaxation(states, mprageE1, mpragerE1) # Record signal - return M0 * 1j * epg.get_signal(states) + return torch.stack(signal)