Skip to content

Commit

Permalink
Add mismatch_ratio_threshold=0 for scatter (DeepLink-org#763)
Browse files Browse the repository at this point in the history
* add mismatch_ratio_threshold=0 for scatter

* increase rtol for batch_norm_stats on Camb
  • Loading branch information
jfxu-st authored Dec 21, 2023
1 parent f60cad1 commit 58436c3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6646,6 +6646,7 @@
name=['scatter'],
interface=['torch'],
is_inplace=True,
mismatch_ratio_threshold=0,
para=dict(
dim=[0, -1, 1, -2, 2, 1, -1],
),
Expand Down Expand Up @@ -6679,6 +6680,7 @@
name=['scatter'],
interface=['torch'],
is_inplace=True,
mismatch_ratio_threshold=0,
para=dict(
dim=[0, -1, 1, 2],
),
Expand Down Expand Up @@ -6711,6 +6713,7 @@
name=['scatter'],
interface=['torch'],
is_inplace=True,
mismatch_ratio_threshold=0,
para=dict(
dim=[2, 1],
reduce=['add', 'multiply']
Expand Down Expand Up @@ -6742,6 +6745,7 @@
name=['scatter'],
interface=['torch'],
is_inplace=True,
mismatch_ratio_threshold=0,
para=dict(
dim=[0, -1, 1, -2, 2, 1, -1],
value=[True, 0.25, -100, 0, 2.34, 20, 1e-4],
Expand Down Expand Up @@ -6769,6 +6773,7 @@
name=['scatter'],
interface=['torch'],
is_inplace=True,
mismatch_ratio_threshold=0,
para=dict(
dim=[2, 1],
value=[-2.31, float("-inf")],
Expand Down
2 changes: 1 addition & 1 deletion impl/camb/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,7 @@
'batch_norm_stats': dict(
name=["batch_norm_stats"],
atol=1e-2,
rtol=1e-3,
rtol=5e-3,
),

'rotary_emb': dict(
Expand Down

0 comments on commit 58436c3

Please sign in to comment.