From 2cd03328abdcc8431ccd8fe23cd5939dcd7af13c Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 20 Feb 2025 06:00:59 -0800 Subject: [PATCH] [XLA:GPU][Emitters] Restrict the inliner. At the moment we inline a callee that calls one of the functions called by the caller. This PR adjusts the logic to inline only callees that call a subset of the functions called by the caller. That way we can be sure that, after inlining, the caller calls the same set of functions, excluding the inlined one. Background: https://github.com/jax-ml/jax/issues/26162 contains an example of a MoF fusion that takes forever to compile. The [indexing-based partitioner](https://github.com/openxla/xla/commit/44bc81687d83a328b9212f6ecf3e6c18a79bd5fd) in combination with this change fixes the issue. PiperOrigin-RevId: 729079659 --- .../emitters/tests/loop/mof_reshapes.hlo | 332 ++++++++++++++++++ xla/codegen/emitters/ir/xla_dialect.cc | 24 +- 2 files changed, 345 insertions(+), 11 deletions(-) create mode 100644 xla/backends/gpu/codegen/emitters/tests/loop/mof_reshapes.hlo diff --git a/xla/backends/gpu/codegen/emitters/tests/loop/mof_reshapes.hlo b/xla/backends/gpu/codegen/emitters/tests/loop/mof_reshapes.hlo new file mode 100644 index 00000000000000..ea30e5c48a1c89 --- /dev/null +++ b/xla/backends/gpu/codegen/emitters/tests/loop/mof_reshapes.hlo @@ -0,0 +1,332 @@ +// RUN: test_correctness %s +// https://github.com/jax-ml/jax/issues/26162 + +HloModule m +f { + p0 = f32[4,8192,1]{2,1,0} parameter(0) + p1 = f32[4,8192,1]{2,1,0} parameter(1) + p2 = f32[4,8192,1]{2,1,0} parameter(2) + p3 = f32[8192,1]{1,0} parameter(3) + p4 = f32[8192,1]{1,0} parameter(4) + p5 = f32[8192,1]{1,0} parameter(5) + p6 = f32[8192,1]{1,0} parameter(6) + constant_64801_24_clone_1 = f32[] constant(0.707106769) + broadcast.81498.40.clone.1 = f32[8192,1]{1,0} broadcast(constant_64801_24_clone_1), dimensions={} + multiply.60143.5.clone.1 = f32[8192,1]{1,0} multiply(p6, broadcast.81498.40.clone.1) + constant_64802_11_clone_1 = f32[] 
constant(-0.707106769) + broadcast.81521.32.clone.1 = f32[8192,1]{1,0} broadcast(constant_64802_11_clone_1), dimensions={} + multiply.60144.5.clone.1 = f32[8192,1]{1,0} multiply(p5, broadcast.81521.32.clone.1) + subtract.27368.3.clone.1 = f32[8192,1]{1,0} subtract(multiply.60143.5.clone.1, multiply.60144.5.clone.1) + constant_64743_277_clone_1 = f32[] constant(0) + broadcast.81461.143.clone.1 = f32[8192,1]{1,0} broadcast(constant_64743_277_clone_1), dimensions={} + multiply.60145.9.clone.1 = f32[8192,1]{1,0} multiply(p4, broadcast.81461.143.clone.1) + subtract.27369.3.clone.1 = f32[8192,1]{1,0} subtract(subtract.27368.3.clone.1, multiply.60145.9.clone.1) + multiply.60146.5.clone.1 = f32[8192,1]{1,0} multiply(p3, broadcast.81461.143.clone.1) + subtract.27370.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27369.3.clone.1, multiply.60146.5.clone.1) + slice.39304.2.clone.1 = f32[1,8192,1]{2,1,0} slice(p2), slice={[0:1], [0:8192], [0:1]} + bitcast.136254.22.clone.1 = f32[8192,1]{1,0} bitcast(slice.39304.2.clone.1) + multiply.60149.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27370.1.clone.1, bitcast.136254.22.clone.1) + multiply.60150.5.clone.1 = f32[8192,1]{1,0} multiply(p6, broadcast.81521.32.clone.1) + multiply.60151.5.clone.1 = f32[8192,1]{1,0} multiply(p5, broadcast.81498.40.clone.1) + add.56861.3.clone.1 = f32[8192,1]{1,0} add(multiply.60150.5.clone.1, multiply.60151.5.clone.1) + add.56862.3.clone.1 = f32[8192,1]{1,0} add(add.56861.3.clone.1, multiply.60145.9.clone.1) + subtract.27371.1.clone.1 = f32[8192,1]{1,0} subtract(add.56862.3.clone.1, multiply.60146.5.clone.1) + slice.39305.8.clone.1 = f32[1,8192,1]{2,1,0} slice(p2), slice={[1:2], [0:8192], [0:1]} + bitcast.136279.14.clone.1 = f32[8192,1]{1,0} bitcast(slice.39305.8.clone.1) + multiply.60152.3.clone.1 = f32[8192,1]{1,0} multiply(subtract.27371.1.clone.1, bitcast.136279.14.clone.1) + subtract.27372.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60149.7.clone.1, multiply.60152.3.clone.1) + 
multiply.60153.5.clone.1 = f32[8192,1]{1,0} multiply(p6, broadcast.81461.143.clone.1) + multiply.60154.11.clone.1 = f32[8192,1]{1,0} multiply(p5, broadcast.81461.143.clone.1) + subtract.27373.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60153.5.clone.1, multiply.60154.11.clone.1) + multiply.60155.3.clone.1 = f32[8192,1]{1,0} multiply(p4, broadcast.81498.40.clone.1) + add.56863.3.clone.1 = f32[8192,1]{1,0} add(subtract.27373.5.clone.1, multiply.60155.3.clone.1) + multiply.60156.3.clone.1 = f32[8192,1]{1,0} multiply(p3, broadcast.81521.32.clone.1) + add.56864.1.clone.1 = f32[8192,1]{1,0} add(add.56863.3.clone.1, multiply.60156.3.clone.1) + slice.39306.12.clone.1 = f32[1,8192,1]{2,1,0} slice(p2), slice={[2:3], [0:8192], [0:1]} + bitcast.136313.16.clone.1 = f32[8192,1]{1,0} bitcast(slice.39306.12.clone.1) + multiply.60157.3.clone.1 = f32[8192,1]{1,0} multiply(add.56864.1.clone.1, bitcast.136313.16.clone.1) + subtract.27374.3.clone.1 = f32[8192,1]{1,0} subtract(subtract.27372.5.clone.1, multiply.60157.3.clone.1) + add.56865.5.clone.1 = f32[8192,1]{1,0} add(multiply.60153.5.clone.1, multiply.60154.11.clone.1) + multiply.60158.3.clone.1 = f32[8192,1]{1,0} multiply(p4, broadcast.81521.32.clone.1) + subtract.27375.3.clone.1 = f32[8192,1]{1,0} subtract(add.56865.5.clone.1, multiply.60158.3.clone.1) + multiply.60159.3.clone.1 = f32[8192,1]{1,0} multiply(p3, broadcast.81498.40.clone.1) + add.56866.1.clone.1 = f32[8192,1]{1,0} add(subtract.27375.3.clone.1, multiply.60159.3.clone.1) + slice.39308.12.clone.1 = f32[1,8192,1]{2,1,0} slice(p2), slice={[3:4], [0:8192], [0:1]} + bitcast.136339.16.clone.1 = f32[8192,1]{1,0} bitcast(slice.39308.12.clone.1) + multiply.60160.3.clone.1 = f32[8192,1]{1,0} multiply(add.56866.1.clone.1, bitcast.136339.16.clone.1) + subtract.27376.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27374.3.clone.1, multiply.60160.3.clone.1) + multiply.60173.9.clone.1 = f32[8192,1]{1,0} multiply(subtract.27376.1.clone.1, broadcast.81498.40.clone.1) + 
multiply.60161.3.clone.1 = f32[8192,1]{1,0} multiply(subtract.27370.1.clone.1, bitcast.136279.14.clone.1) + multiply.60162.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27371.1.clone.1, bitcast.136254.22.clone.1) + add.56867.5.clone.1 = f32[8192,1]{1,0} add(multiply.60161.3.clone.1, multiply.60162.7.clone.1) + multiply.60163.3.clone.1 = f32[8192,1]{1,0} multiply(add.56864.1.clone.1, bitcast.136339.16.clone.1) + add.56868.3.clone.1 = f32[8192,1]{1,0} add(add.56867.5.clone.1, multiply.60163.3.clone.1) + multiply.60164.3.clone.1 = f32[8192,1]{1,0} multiply(add.56866.1.clone.1, bitcast.136313.16.clone.1) + subtract.27377.1.clone.1 = f32[8192,1]{1,0} subtract(add.56868.3.clone.1, multiply.60164.3.clone.1) + multiply.60174.9.clone.1 = f32[8192,1]{1,0} multiply(subtract.27377.1.clone.1, broadcast.81498.40.clone.1) + subtract.27380.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60173.9.clone.1, multiply.60174.9.clone.1) + multiply.60165.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27370.1.clone.1, bitcast.136313.16.clone.1) + multiply.60166.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27371.1.clone.1, bitcast.136339.16.clone.1) + subtract.27378.3.clone.1 = f32[8192,1]{1,0} subtract(multiply.60165.5.clone.1, multiply.60166.5.clone.1) + multiply.60167.5.clone.1 = f32[8192,1]{1,0} multiply(add.56864.1.clone.1, bitcast.136254.22.clone.1) + add.56869.3.clone.1 = f32[8192,1]{1,0} add(subtract.27378.3.clone.1, multiply.60167.5.clone.1) + multiply.60168.3.clone.1 = f32[8192,1]{1,0} multiply(add.56866.1.clone.1, bitcast.136279.14.clone.1) + add.56870.1.clone.1 = f32[8192,1]{1,0} add(add.56869.3.clone.1, multiply.60168.3.clone.1) + multiply.60175.15.clone.1 = f32[8192,1]{1,0} multiply(add.56870.1.clone.1, broadcast.81461.143.clone.1) + subtract.27381.7.clone.1 = f32[8192,1]{1,0} subtract(subtract.27380.5.clone.1, multiply.60175.15.clone.1) + multiply.60169.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27370.1.clone.1, bitcast.136339.16.clone.1) + multiply.60170.5.clone.1 = 
f32[8192,1]{1,0} multiply(subtract.27371.1.clone.1, bitcast.136313.16.clone.1) + add.56871.3.clone.1 = f32[8192,1]{1,0} add(multiply.60169.5.clone.1, multiply.60170.5.clone.1) + multiply.60171.5.clone.1 = f32[8192,1]{1,0} multiply(add.56864.1.clone.1, bitcast.136279.14.clone.1) + subtract.27379.3.clone.1 = f32[8192,1]{1,0} subtract(add.56871.3.clone.1, multiply.60171.5.clone.1) + multiply.60172.3.clone.1 = f32[8192,1]{1,0} multiply(add.56866.1.clone.1, bitcast.136254.22.clone.1) + add.56872.1.clone.1 = f32[8192,1]{1,0} add(subtract.27379.3.clone.1, multiply.60172.3.clone.1) + multiply.60176.11.clone.1 = f32[8192,1]{1,0} multiply(add.56872.1.clone.1, broadcast.81461.143.clone.1) + subtract.27382.5.clone.1 = f32[8192,1]{1,0} subtract(subtract.27381.7.clone.1, multiply.60176.11.clone.1) + slice.39310.2.clone.1 = f32[1,8192,1]{2,1,0} slice(p1), slice={[0:1], [0:8192], [0:1]} + bitcast.136481.22.clone.1 = f32[8192,1]{1,0} bitcast(slice.39310.2.clone.1) + multiply.60179.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27382.5.clone.1, bitcast.136481.22.clone.1) + add.56873.1.clone.1 = f32[8192,1]{1,0} add(multiply.60173.9.clone.1, multiply.60174.9.clone.1) + add.56874.3.clone.1 = f32[8192,1]{1,0} add(add.56873.1.clone.1, multiply.60175.15.clone.1) + subtract.27383.1.clone.1 = f32[8192,1]{1,0} subtract(add.56874.3.clone.1, multiply.60176.11.clone.1) + slice.39311.8.clone.1 = f32[1,8192,1]{2,1,0} slice(p1), slice={[1:2], [0:8192], [0:1]} + bitcast.136498.14.clone.1 = f32[8192,1]{1,0} bitcast(slice.39311.8.clone.1) + multiply.60180.3.clone.1 = f32[8192,1]{1,0} multiply(subtract.27383.1.clone.1, bitcast.136498.14.clone.1) + subtract.27384.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60179.7.clone.1, multiply.60180.3.clone.1) + multiply.60181.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27376.1.clone.1, broadcast.81461.143.clone.1) + multiply.60182.11.clone.1 = f32[8192,1]{1,0} multiply(subtract.27377.1.clone.1, broadcast.81461.143.clone.1) + subtract.27385.5.clone.1 = 
f32[8192,1]{1,0} subtract(multiply.60181.5.clone.1, multiply.60182.11.clone.1) + multiply.60183.5.clone.1 = f32[8192,1]{1,0} multiply(add.56870.1.clone.1, broadcast.81498.40.clone.1) + add.56875.3.clone.1 = f32[8192,1]{1,0} add(subtract.27385.5.clone.1, multiply.60183.5.clone.1) + multiply.60184.5.clone.1 = f32[8192,1]{1,0} multiply(add.56872.1.clone.1, broadcast.81498.40.clone.1) + add.56876.1.clone.1 = f32[8192,1]{1,0} add(add.56875.3.clone.1, multiply.60184.5.clone.1) + slice.39312.12.clone.1 = f32[1,8192,1]{2,1,0} slice(p1), slice={[2:3], [0:8192], [0:1]} + bitcast.136532.16.clone.1 = f32[8192,1]{1,0} bitcast(slice.39312.12.clone.1) + multiply.60185.3.clone.1 = f32[8192,1]{1,0} multiply(add.56876.1.clone.1, bitcast.136532.16.clone.1) + subtract.27386.3.clone.1 = f32[8192,1]{1,0} subtract(subtract.27384.5.clone.1, multiply.60185.3.clone.1) + add.56877.9.clone.1 = f32[8192,1]{1,0} add(multiply.60181.5.clone.1, multiply.60182.11.clone.1) + subtract.27387.7.clone.1 = f32[8192,1]{1,0} subtract(add.56877.9.clone.1, multiply.60183.5.clone.1) + add.56878.5.clone.1 = f32[8192,1]{1,0} add(subtract.27387.7.clone.1, multiply.60184.5.clone.1) + slice.39313.12.clone.1 = f32[1,8192,1]{2,1,0} slice(p1), slice={[3:4], [0:8192], [0:1]} + bitcast.136550.16.clone.1 = f32[8192,1]{1,0} bitcast(slice.39313.12.clone.1) + multiply.60186.3.clone.1 = f32[8192,1]{1,0} multiply(add.56878.5.clone.1, bitcast.136550.16.clone.1) + subtract.27388.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27386.3.clone.1, multiply.60186.3.clone.1) + multiply.60199.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27388.1.clone.1, broadcast.81498.40.clone.1) + multiply.60187.3.clone.1 = f32[8192,1]{1,0} multiply(subtract.27382.5.clone.1, bitcast.136498.14.clone.1) + multiply.60188.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27383.1.clone.1, bitcast.136481.22.clone.1) + add.56879.5.clone.1 = f32[8192,1]{1,0} add(multiply.60187.3.clone.1, multiply.60188.7.clone.1) + multiply.60189.3.clone.1 = f32[8192,1]{1,0} 
multiply(add.56876.1.clone.1, bitcast.136550.16.clone.1) + add.56880.3.clone.1 = f32[8192,1]{1,0} add(add.56879.5.clone.1, multiply.60189.3.clone.1) + multiply.60190.3.clone.1 = f32[8192,1]{1,0} multiply(add.56878.5.clone.1, bitcast.136532.16.clone.1) + subtract.27389.1.clone.1 = f32[8192,1]{1,0} subtract(add.56880.3.clone.1, multiply.60190.3.clone.1) + multiply.60200.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27389.1.clone.1, broadcast.81498.40.clone.1) + subtract.27392.1.clone.1 = f32[8192,1]{1,0} subtract(multiply.60199.5.clone.1, multiply.60200.5.clone.1) + multiply.60191.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27382.5.clone.1, bitcast.136532.16.clone.1) + multiply.60192.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27383.1.clone.1, bitcast.136550.16.clone.1) + subtract.27390.3.clone.1 = f32[8192,1]{1,0} subtract(multiply.60191.5.clone.1, multiply.60192.5.clone.1) + multiply.60193.5.clone.1 = f32[8192,1]{1,0} multiply(add.56876.1.clone.1, bitcast.136481.22.clone.1) + add.56881.3.clone.1 = f32[8192,1]{1,0} add(subtract.27390.3.clone.1, multiply.60193.5.clone.1) + multiply.60194.3.clone.1 = f32[8192,1]{1,0} multiply(add.56878.5.clone.1, bitcast.136498.14.clone.1) + add.56882.1.clone.1 = f32[8192,1]{1,0} add(add.56881.3.clone.1, multiply.60194.3.clone.1) + multiply.60201.9.clone.1 = f32[8192,1]{1,0} multiply(add.56882.1.clone.1, broadcast.81461.143.clone.1) + subtract.27393.3.clone.1 = f32[8192,1]{1,0} subtract(subtract.27392.1.clone.1, multiply.60201.9.clone.1) + multiply.60195.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27382.5.clone.1, bitcast.136550.16.clone.1) + multiply.60196.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27383.1.clone.1, bitcast.136532.16.clone.1) + add.56883.3.clone.1 = f32[8192,1]{1,0} add(multiply.60195.5.clone.1, multiply.60196.5.clone.1) + multiply.60197.5.clone.1 = f32[8192,1]{1,0} multiply(add.56876.1.clone.1, bitcast.136498.14.clone.1) + subtract.27391.3.clone.1 = f32[8192,1]{1,0} subtract(add.56883.3.clone.1, 
multiply.60197.5.clone.1) + multiply.60198.3.clone.1 = f32[8192,1]{1,0} multiply(add.56878.5.clone.1, bitcast.136481.22.clone.1) + add.56884.1.clone.1 = f32[8192,1]{1,0} add(subtract.27391.3.clone.1, multiply.60198.3.clone.1) + multiply.60202.5.clone.1 = f32[8192,1]{1,0} multiply(add.56884.1.clone.1, broadcast.81461.143.clone.1) + subtract.27394.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27393.3.clone.1, multiply.60202.5.clone.1) + slice.39315.2.clone.1 = f32[1,8192,1]{2,1,0} slice(p0), slice={[0:1], [0:8192], [0:1]} + bitcast.136692.22.clone.1 = f32[8192,1]{1,0} bitcast(slice.39315.2.clone.1) + multiply.60205.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27394.1.clone.1, bitcast.136692.22.clone.1) + add.56885.1.clone.1 = f32[8192,1]{1,0} add(multiply.60199.5.clone.1, multiply.60200.5.clone.1) + add.56886.3.clone.1 = f32[8192,1]{1,0} add(add.56885.1.clone.1, multiply.60201.9.clone.1) + subtract.27395.1.clone.1 = f32[8192,1]{1,0} subtract(add.56886.3.clone.1, multiply.60202.5.clone.1) + slice.39316.8.clone.1 = f32[1,8192,1]{2,1,0} slice(p0), slice={[1:2], [0:8192], [0:1]} + bitcast.136709.14.clone.1 = f32[8192,1]{1,0} bitcast(slice.39316.8.clone.1) + multiply.60206.3.clone.1 = f32[8192,1]{1,0} multiply(subtract.27395.1.clone.1, bitcast.136709.14.clone.1) + subtract.27396.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60205.7.clone.1, multiply.60206.3.clone.1) + multiply.60207.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27388.1.clone.1, broadcast.81461.143.clone.1) + multiply.60208.11.clone.1 = f32[8192,1]{1,0} multiply(subtract.27389.1.clone.1, broadcast.81461.143.clone.1) + subtract.27397.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60207.5.clone.1, multiply.60208.11.clone.1) + multiply.60209.5.clone.1 = f32[8192,1]{1,0} multiply(add.56882.1.clone.1, broadcast.81498.40.clone.1) + add.56887.3.clone.1 = f32[8192,1]{1,0} add(subtract.27397.5.clone.1, multiply.60209.5.clone.1) + multiply.60210.5.clone.1 = f32[8192,1]{1,0} multiply(add.56884.1.clone.1, 
broadcast.81498.40.clone.1) + add.56888.1.clone.1 = f32[8192,1]{1,0} add(add.56887.3.clone.1, multiply.60210.5.clone.1) + slice.39317.12.clone.1 = f32[1,8192,1]{2,1,0} slice(p0), slice={[2:3], [0:8192], [0:1]} + bitcast.136743.16.clone.1 = f32[8192,1]{1,0} bitcast(slice.39317.12.clone.1) + multiply.60211.3.clone.1 = f32[8192,1]{1,0} multiply(add.56888.1.clone.1, bitcast.136743.16.clone.1) + subtract.27398.3.clone.1 = f32[8192,1]{1,0} subtract(subtract.27396.5.clone.1, multiply.60211.3.clone.1) + add.56889.5.clone.1 = f32[8192,1]{1,0} add(multiply.60207.5.clone.1, multiply.60208.11.clone.1) + subtract.27399.3.clone.1 = f32[8192,1]{1,0} subtract(add.56889.5.clone.1, multiply.60209.5.clone.1) + add.56890.1.clone.1 = f32[8192,1]{1,0} add(subtract.27399.3.clone.1, multiply.60210.5.clone.1) + slice.39318.12.clone.1 = f32[1,8192,1]{2,1,0} slice(p0), slice={[3:4], [0:8192], [0:1]} + bitcast.136761.16.clone.1 = f32[8192,1]{1,0} bitcast(slice.39318.12.clone.1) + multiply.60212.3.clone.1 = f32[8192,1]{1,0} multiply(add.56890.1.clone.1, bitcast.136761.16.clone.1) + subtract.27400.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27398.3.clone.1, multiply.60212.3.clone.1) + multiply.60213.3.clone.1 = f32[8192,1]{1,0} multiply(subtract.27394.1.clone.1, bitcast.136709.14.clone.1) + multiply.60214.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27395.1.clone.1, bitcast.136692.22.clone.1) + add.56891.5.clone.1 = f32[8192,1]{1,0} add(multiply.60213.3.clone.1, multiply.60214.7.clone.1) + multiply.60215.3.clone.1 = f32[8192,1]{1,0} multiply(add.56888.1.clone.1, bitcast.136761.16.clone.1) + add.56892.3.clone.1 = f32[8192,1]{1,0} add(add.56891.5.clone.1, multiply.60215.3.clone.1) + multiply.60216.3.clone.1 = f32[8192,1]{1,0} multiply(add.56890.1.clone.1, bitcast.136743.16.clone.1) + subtract.27401.1.clone.1 = f32[8192,1]{1,0} subtract(add.56892.3.clone.1, multiply.60216.3.clone.1) + multiply.60225.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27401.1.clone.1, broadcast.81461.143.clone.1) + 
subtract.27404.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27400.1.clone.1, multiply.60225.5.clone.1) + multiply.60217.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27394.1.clone.1, bitcast.136743.16.clone.1) + multiply.60218.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27395.1.clone.1, bitcast.136761.16.clone.1) + subtract.27402.3.clone.1 = f32[8192,1]{1,0} subtract(multiply.60217.5.clone.1, multiply.60218.5.clone.1) + multiply.60219.5.clone.1 = f32[8192,1]{1,0} multiply(add.56888.1.clone.1, bitcast.136692.22.clone.1) + add.56893.3.clone.1 = f32[8192,1]{1,0} add(subtract.27402.3.clone.1, multiply.60219.5.clone.1) + multiply.60220.3.clone.1 = f32[8192,1]{1,0} multiply(add.56890.1.clone.1, bitcast.136709.14.clone.1) + add.56894.1.clone.1 = f32[8192,1]{1,0} add(add.56893.3.clone.1, multiply.60220.3.clone.1) + multiply.60226.5.clone.1 = f32[8192,1]{1,0} multiply(add.56894.1.clone.1, broadcast.81461.143.clone.1) + subtract.27405.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27404.1.clone.1, multiply.60226.5.clone.1) + multiply.60221.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27394.1.clone.1, bitcast.136761.16.clone.1) + multiply.60222.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27395.1.clone.1, bitcast.136743.16.clone.1) + add.56895.3.clone.1 = f32[8192,1]{1,0} add(multiply.60221.5.clone.1, multiply.60222.5.clone.1) + multiply.60223.5.clone.1 = f32[8192,1]{1,0} multiply(add.56888.1.clone.1, bitcast.136709.14.clone.1) + subtract.27403.3.clone.1 = f32[8192,1]{1,0} subtract(add.56895.3.clone.1, multiply.60223.5.clone.1) + multiply.60224.3.clone.1 = f32[8192,1]{1,0} multiply(add.56890.1.clone.1, bitcast.136692.22.clone.1) + add.56896.1.clone.1 = f32[8192,1]{1,0} add(subtract.27403.3.clone.1, multiply.60224.3.clone.1) + multiply.60227.5.clone.1 = f32[8192,1]{1,0} multiply(add.56896.1.clone.1, broadcast.81461.143.clone.1) + subtract.27406.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27405.1.clone.1, multiply.60227.5.clone.1) + multiply.60228.5.clone.1 = 
f32[8192,1]{1,0} multiply(subtract.27400.1.clone.1, broadcast.81461.143.clone.1) + add.56897.1.clone.1 = f32[8192,1]{1,0} add(multiply.60228.5.clone.1, subtract.27401.1.clone.1) + add.56898.1.clone.1 = f32[8192,1]{1,0} add(add.56897.1.clone.1, multiply.60226.5.clone.1) + subtract.27407.1.clone.1 = f32[8192,1]{1,0} subtract(add.56898.1.clone.1, multiply.60227.5.clone.1) + multiply.60229.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27407.1.clone.1, broadcast.81461.143.clone.1) + subtract.27410.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27406.1.clone.1, multiply.60229.5.clone.1) + subtract.27408.1.clone.1 = f32[8192,1]{1,0} subtract(multiply.60228.5.clone.1, multiply.60225.5.clone.1) + add.56899.1.clone.1 = f32[8192,1]{1,0} add(subtract.27408.1.clone.1, add.56894.1.clone.1) + add.56900.1.clone.1 = f32[8192,1]{1,0} add(add.56899.1.clone.1, multiply.60227.5.clone.1) + multiply.60230.5.clone.1 = f32[8192,1]{1,0} multiply(add.56900.1.clone.1, broadcast.81461.143.clone.1) + subtract.27411.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27410.1.clone.1, multiply.60230.5.clone.1) + add.56901.1.clone.1 = f32[8192,1]{1,0} add(multiply.60228.5.clone.1, multiply.60225.5.clone.1) + subtract.27409.1.clone.1 = f32[8192,1]{1,0} subtract(add.56901.1.clone.1, multiply.60226.5.clone.1) + add.56902.1.clone.1 = f32[8192,1]{1,0} add(subtract.27409.1.clone.1, add.56896.1.clone.1) + multiply.60231.5.clone.1 = f32[8192,1]{1,0} multiply(add.56902.1.clone.1, broadcast.81461.143.clone.1) + subtract.27412.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27411.1.clone.1, multiply.60231.5.clone.1) + multiply.60232.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27406.1.clone.1, broadcast.81461.143.clone.1) + add.56903.1.clone.1 = f32[8192,1]{1,0} add(multiply.60232.5.clone.1, subtract.27407.1.clone.1) + add.56904.1.clone.1 = f32[8192,1]{1,0} add(add.56903.1.clone.1, multiply.60230.5.clone.1) + subtract.27413.1.clone.1 = f32[8192,1]{1,0} subtract(add.56904.1.clone.1, multiply.60231.5.clone.1) + 
multiply.60233.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27413.1.clone.1, broadcast.81461.143.clone.1) + subtract.27416.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27412.1.clone.1, multiply.60233.5.clone.1) + subtract.27414.1.clone.1 = f32[8192,1]{1,0} subtract(multiply.60232.5.clone.1, multiply.60229.5.clone.1) + add.56905.1.clone.1 = f32[8192,1]{1,0} add(subtract.27414.1.clone.1, add.56900.1.clone.1) + add.56906.1.clone.1 = f32[8192,1]{1,0} add(add.56905.1.clone.1, multiply.60231.5.clone.1) + multiply.60234.5.clone.1 = f32[8192,1]{1,0} multiply(add.56906.1.clone.1, broadcast.81461.143.clone.1) + subtract.27417.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27416.1.clone.1, multiply.60234.5.clone.1) + add.56907.1.clone.1 = f32[8192,1]{1,0} add(multiply.60232.5.clone.1, multiply.60229.5.clone.1) + subtract.27415.1.clone.1 = f32[8192,1]{1,0} subtract(add.56907.1.clone.1, multiply.60230.5.clone.1) + add.56908.1.clone.1 = f32[8192,1]{1,0} add(subtract.27415.1.clone.1, add.56902.1.clone.1) + multiply.60235.5.clone.1 = f32[8192,1]{1,0} multiply(add.56908.1.clone.1, broadcast.81461.143.clone.1) + subtract.27418.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27417.1.clone.1, multiply.60235.5.clone.1) + multiply.60236.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27412.1.clone.1, broadcast.81461.143.clone.1) + add.56909.1.clone.1 = f32[8192,1]{1,0} add(multiply.60236.5.clone.1, subtract.27413.1.clone.1) + add.56910.1.clone.1 = f32[8192,1]{1,0} add(add.56909.1.clone.1, multiply.60234.5.clone.1) + subtract.27419.1.clone.1 = f32[8192,1]{1,0} subtract(add.56910.1.clone.1, multiply.60235.5.clone.1) + multiply.60237.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27419.1.clone.1, broadcast.81461.143.clone.1) + subtract.27422.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27418.1.clone.1, multiply.60237.5.clone.1) + subtract.27420.1.clone.1 = f32[8192,1]{1,0} subtract(multiply.60236.5.clone.1, multiply.60233.5.clone.1) + add.56911.1.clone.1 = f32[8192,1]{1,0} 
add(subtract.27420.1.clone.1, add.56906.1.clone.1) + add.56912.1.clone.1 = f32[8192,1]{1,0} add(add.56911.1.clone.1, multiply.60235.5.clone.1) + multiply.60238.5.clone.1 = f32[8192,1]{1,0} multiply(add.56912.1.clone.1, broadcast.81461.143.clone.1) + subtract.27423.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27422.1.clone.1, multiply.60238.5.clone.1) + add.56913.1.clone.1 = f32[8192,1]{1,0} add(multiply.60236.5.clone.1, multiply.60233.5.clone.1) + subtract.27421.1.clone.1 = f32[8192,1]{1,0} subtract(add.56913.1.clone.1, multiply.60234.5.clone.1) + add.56914.1.clone.1 = f32[8192,1]{1,0} add(subtract.27421.1.clone.1, add.56908.1.clone.1) + multiply.60239.5.clone.1 = f32[8192,1]{1,0} multiply(add.56914.1.clone.1, broadcast.81461.143.clone.1) + subtract.27424.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27423.1.clone.1, multiply.60239.5.clone.1) + multiply.60241.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27424.1.clone.1, broadcast.81498.40.clone.1) + multiply.60240.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27418.1.clone.1, broadcast.81461.143.clone.1) + add.56915.1.clone.1 = f32[8192,1]{1,0} add(multiply.60240.5.clone.1, subtract.27419.1.clone.1) + add.56916.1.clone.1 = f32[8192,1]{1,0} add(add.56915.1.clone.1, multiply.60238.5.clone.1) + subtract.27425.1.clone.1 = f32[8192,1]{1,0} subtract(add.56916.1.clone.1, multiply.60239.5.clone.1) + multiply.60242.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27425.1.clone.1, broadcast.81461.143.clone.1) + subtract.27428.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60241.7.clone.1, multiply.60242.5.clone.1) + subtract.27426.1.clone.1 = f32[8192,1]{1,0} subtract(multiply.60240.5.clone.1, multiply.60237.5.clone.1) + add.56917.1.clone.1 = f32[8192,1]{1,0} add(subtract.27426.1.clone.1, add.56912.1.clone.1) + add.56918.1.clone.1 = f32[8192,1]{1,0} add(add.56917.1.clone.1, multiply.60239.5.clone.1) + multiply.60243.5.clone.1 = f32[8192,1]{1,0} multiply(add.56918.1.clone.1, broadcast.81461.143.clone.1) + 
subtract.27429.3.clone.1 = f32[8192,1]{1,0} subtract(subtract.27428.5.clone.1, multiply.60243.5.clone.1) + add.56919.1.clone.1 = f32[8192,1]{1,0} add(multiply.60240.5.clone.1, multiply.60237.5.clone.1) + subtract.27427.1.clone.1 = f32[8192,1]{1,0} subtract(add.56919.1.clone.1, multiply.60238.5.clone.1) + add.56920.1.clone.1 = f32[8192,1]{1,0} add(subtract.27427.1.clone.1, add.56914.1.clone.1) + multiply.60244.3.clone.1 = f32[8192,1]{1,0} multiply(add.56920.1.clone.1, broadcast.81521.32.clone.1) + subtract.27430.1.clone.1 = f32[8192,1]{1,0} subtract(subtract.27429.3.clone.1, multiply.60244.3.clone.1) + bitcast.137156.7 = f32[1,1,8192]{2,1,0} bitcast(subtract.27430.1.clone.1) + multiply.60245.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27424.1.clone.1, broadcast.81461.143.clone.1) + multiply.60246.5.clone.1 = f32[8192,1]{1,0} multiply(subtract.27425.1.clone.1, broadcast.81498.40.clone.1) + add.56921.3.clone.1 = f32[8192,1]{1,0} add(multiply.60245.5.clone.1, multiply.60246.5.clone.1) + multiply.60247.5.clone.1 = f32[8192,1]{1,0} multiply(add.56918.1.clone.1, broadcast.81521.32.clone.1) + add.56922.3.clone.1 = f32[8192,1]{1,0} add(add.56921.3.clone.1, multiply.60247.5.clone.1) + multiply.60248.5.clone.1 = f32[8192,1]{1,0} multiply(add.56920.1.clone.1, broadcast.81461.143.clone.1) + subtract.27431.1.clone.1 = f32[8192,1]{1,0} subtract(add.56922.3.clone.1, multiply.60248.5.clone.1) + bitcast.137185.7 = f32[1,1,8192]{2,1,0} bitcast(subtract.27431.1.clone.1) + multiply.60249.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27425.1.clone.1, broadcast.81521.32.clone.1) + subtract.27432.5.clone.1 = f32[8192,1]{1,0} subtract(multiply.60245.5.clone.1, multiply.60249.7.clone.1) + multiply.60250.3.clone.1 = f32[8192,1]{1,0} multiply(add.56918.1.clone.1, broadcast.81498.40.clone.1) + add.56923.3.clone.1 = f32[8192,1]{1,0} add(subtract.27432.5.clone.1, multiply.60250.3.clone.1) + add.56924.1.clone.1 = f32[8192,1]{1,0} add(add.56923.3.clone.1, multiply.60248.5.clone.1) + 
bitcast.137206.7 = f32[1,1,8192]{2,1,0} bitcast(add.56924.1.clone.1) + multiply.60251.7.clone.1 = f32[8192,1]{1,0} multiply(subtract.27424.1.clone.1, broadcast.81521.32.clone.1) + add.56925.5.clone.1 = f32[8192,1]{1,0} add(multiply.60251.7.clone.1, multiply.60242.5.clone.1) + subtract.27433.3.clone.1 = f32[8192,1]{1,0} subtract(add.56925.5.clone.1, multiply.60243.5.clone.1) + multiply.60252.3.clone.1 = f32[8192,1]{1,0} multiply(add.56920.1.clone.1, broadcast.81498.40.clone.1) + add.56926.1.clone.1 = f32[8192,1]{1,0} add(subtract.27433.3.clone.1, multiply.60252.3.clone.1) + bitcast.137227.7 = f32[1,1,8192]{2,1,0} bitcast(add.56926.1.clone.1) + concatenate.12201.7 = f32[4,1,8192]{2,1,0} concatenate(bitcast.137156.7, bitcast.137185.7, bitcast.137206.7, bitcast.137227.7), dimensions={0} + bitcast.137232.1 = f32[4,8192]{1,0} bitcast(concatenate.12201.7) + slice.39319.1 = f32[1,8192]{1,0} slice(bitcast.137232.1), slice={[0:1], [0:8192]} + slice.39320.1.clone.1 = f32[1,8192]{1,0} slice(bitcast.137232.1), slice={[1:2], [0:8192]} + slice.39321.1.clone.1 = f32[1,8192]{1,0} slice(bitcast.137232.1), slice={[2:3], [0:8192]} + slice.39322.1.clone.1 = f32[1,8192]{1,0} slice(bitcast.137232.1), slice={[3:4], [0:8192]} + ROOT tuple.3653 = ( + f32[1,8192]{1,0}, f32[1,8192]{1,0}, f32[1,8192]{1,0}, f32[1,8192]{1,0}, f32[8192,1]{1,0}, +/*index=5*/ f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=10*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=15*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=20*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=25*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=30*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=35*/f32[8192,1]{1,0}, 
f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=40*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=45*/f32[8192,1]{1,0}) tuple( + slice.39319.1, slice.39320.1.clone.1, slice.39321.1.clone.1, slice.39322.1.clone.1, subtract.27430.1.clone.1, +/*index=5*/ subtract.27431.1.clone.1, add.56924.1.clone.1, add.56926.1.clone.1, add.56918.1.clone.1, subtract.27425.1.clone.1, +/*index=10*/add.56920.1.clone.1, subtract.27424.1.clone.1, add.56912.1.clone.1, subtract.27419.1.clone.1, add.56914.1.clone.1, +/*index=15*/subtract.27418.1.clone.1, subtract.27413.1.clone.1, add.56906.1.clone.1, subtract.27412.1.clone.1, add.56908.1.clone.1, +/*index=20*/add.56900.1.clone.1, subtract.27407.1.clone.1, subtract.27406.1.clone.1, add.56902.1.clone.1, subtract.27401.1.clone.1, +/*index=25*/add.56894.1.clone.1, subtract.27400.1.clone.1, add.56896.1.clone.1, subtract.27394.1.clone.1, subtract.27395.1.clone.1, +/*index=30*/add.56888.1.clone.1, add.56890.1.clone.1, add.56882.1.clone.1, add.56884.1.clone.1, subtract.27388.1.clone.1, +/*index=35*/subtract.27389.1.clone.1, add.56876.1.clone.1, subtract.27383.1.clone.1, subtract.27376.1.clone.1, subtract.27377.1.clone.1, +/*index=40*/add.56870.1.clone.1, add.56872.1.clone.1, subtract.27370.1.clone.1, subtract.27371.1.clone.1, add.56864.1.clone.1, +/*index=45*/add.56866.1.clone.1) +} // fused_add_slice_subtrac + +ENTRY main { + p6 = f32[8192,1]{1,0} parameter(6) + p5 = f32[8192,1]{1,0} parameter(5) + p4 = f32[8192,1]{1,0} parameter(4) + p3 = f32[8192,1]{1,0} parameter(3) + p2 = f32[4,8192,1]{2,1,0} parameter(2) + p1 = f32[4,8192,1]{2,1,0} parameter(1) + p0 = f32[4,8192,1]{2,1,0} parameter(0) + ROOT fusion = ( + f32[1,8192]{1,0}, f32[1,8192]{1,0}, f32[1,8192]{1,0}, f32[1,8192]{1,0}, f32[8192,1]{1,0}, +/*index=5*/ f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=10*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, 
f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=15*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=20*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=25*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=30*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=35*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=40*/f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, f32[8192,1]{1,0}, +/*index=45*/f32[8192,1]{1,0}) fusion( + p0, p1, p2, p3, p4, p5, p6), kind=kLoop, calls=f +} diff --git a/xla/codegen/emitters/ir/xla_dialect.cc b/xla/codegen/emitters/ir/xla_dialect.cc index 04140de8805722..aa8d3c6bfd76ea 100644 --- a/xla/codegen/emitters/ir/xla_dialect.cc +++ b/xla/codegen/emitters/ir/xla_dialect.cc @@ -13,9 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" #include "xla/codegen/emitters/ir/xla_ops.h" @@ -57,22 +64,17 @@ struct XlaInlinerInterface : public mlir::DialectInlinerInterface { return false; } - // If callee and caller call the same third function, inline. 
We have no - // guarantee that the indices are the same, but there is a good chance they - // are (or if the callee gets inlined as well, there will be CSE - // opportunities). - // This is duct tape to work around the limitations of our partitioner. - // Ideally, the partitioner would be aware of the actual indexing and create - // the partitions based on it (i.e., the case where the indices are the same - // would never happen). + // If callee calls a subset of functions of its caller, then we inline. llvm::SmallDenseSet callee_calls; for (auto call : region->getOps()) { callee_calls.insert(call.getCallee()); } + llvm::SmallDenseSet caller_calls; for (auto call : call->getParentRegion()->getOps()) { - if (callee_calls.contains(call.getCallee())) { - return true; - } + caller_calls.insert(call.getCallee()); + } + if (!llvm::set_is_subset(callee_calls, caller_calls)) { + return false; } constexpr int kMaxOperationsToInline = 8;