diff --git a/xrft/tests/test_xrft.py b/xrft/tests/test_xrft.py
index dc661879..7b6a9301 100644
--- a/xrft/tests/test_xrft.py
+++ b/xrft/tests/test_xrft.py
@@ -12,6 +12,7 @@
 import xarray.testing as xrt
 
 import xrft
+from ..xrft import _apply_window
 
 
 @pytest.fixture()
@@ -524,13 +525,30 @@ def test_cross_spectrum(self, dask):
         cs = xrft.cross_spectrum(
             da, da2, dim=dim, shift=True, window="hann", detrend="constant"
         )
-        test = (daft * np.conj(daft2)).values / N ** 4
+        test = (daft * np.conj(daft2)) / N ** 4
 
         dk = np.diff(np.fft.fftfreq(N, 1.0))[0]
         test /= dk ** 2
         npt.assert_almost_equal(cs.values, test)
         npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.0)
 
+        cs = xrft.cross_spectrum(
+            da,
+            da2,
+            dim=dim,
+            shift=True,
+            window="hann",
+            detrend="constant",
+            window_correction=True,
+        )
+        test = (daft * np.conj(daft2)) / N ** 4
+        window, _ = _apply_window(da, dim, window_type="hann")
+        dk = np.diff(np.fft.fftfreq(N, 1.0))[0]
+        test /= dk ** 2 * (window ** 2).mean()
+
+        npt.assert_almost_equal(cs.values, test)
+        npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.0)
+
         with pytest.raises(ValueError):
             xrft.cross_spectrum(da, da2, dim=dim, window=None, window_correction=True)
 
diff --git a/xrft/xrft.py b/xrft/xrft.py
index fb7bc622..b5abb8f5 100644
--- a/xrft/xrft.py
+++ b/xrft/xrft.py
@@ -863,7 +863,7 @@ def cross_spectrum(
                     "window_correction can only be applied when windowing is turned on."
                 )
             else:
-                windows, _ = _apply_window(da, dim, window_type=kwargs.get("window"))
+                windows, _ = _apply_window(da1, dim, window_type=kwargs.get("window"))
                 cs = cs / (windows ** 2).mean()
         fs = np.prod([float(cs[d].spacing) for d in updated_dims])
         cs *= fs
@@ -874,7 +874,7 @@ def cross_spectrum(
                     "window_correction can only be applied when windowing is turned on."
                 )
             else:
-                windows, _ = _apply_window(da, dim, window_type=kwargs.get("window"))
+                windows, _ = _apply_window(da1, dim, window_type=kwargs.get("window"))
                 cs = cs / windows.mean() ** 2
         fs = np.prod([float(cs[d].spacing) for d in updated_dims])
         cs *= fs ** 2