Skip to content

Commit

Permalink
zsk/switch variable name (DeepLink-org#759)
Browse files Browse the repository at this point in the history
switch variable name
tensor1, tensor2 -> tensor_dev, tensor_ref
  • Loading branch information
zsksmhq authored Dec 20, 2023
1 parent 1d8a97c commit f686147
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions diopi_test/python/conformance/check_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,55 +72,55 @@ def compare_num(output, output_reference, **kwargs):
CheckResult.allclose(output, output_reference, **kwargs)

@staticmethod
def allclose(tensor1: np.ndarray, tensor2: np.ndarray, **kwargs) -> bool:
def allclose(tensor_dev: np.ndarray, tensor_ref: np.ndarray, **kwargs) -> bool:
var_name = kwargs.get('name', 'out')
sum_to_compare = kwargs.get('sum_to_compare', False)
rtol = kwargs.get('rtol', 1e-5)
atol = kwargs.get('atol', 1e-8)
mismatch_ratio_threshold = kwargs.get('mismatch_ratio_threshold', 1e-3)
tensor1 = np.sum(tensor1) if sum_to_compare else tensor1
tensor2 = np.sum(tensor2) if sum_to_compare else tensor2
matched = np.isclose(tensor1, tensor2, rtol, atol, equal_nan=True)
tensor_dev = np.sum(tensor_dev) if sum_to_compare else tensor_dev
tensor_ref = np.sum(tensor_ref) if sum_to_compare else tensor_ref
matched = np.isclose(tensor_dev, tensor_ref, rtol, atol, equal_nan=True)
mismatched_num = matched.size - np.sum(matched)
passed = mismatched_num <= mismatch_ratio_threshold * matched.size
glob_vars.func_status[glob_vars.cur_test_func] = 'passed'
if not passed:
glob_vars.func_status[glob_vars.cur_test_func] = 'failed'
sum1 = tensor1.sum()
sum2 = tensor2.sum()
mask = np.isclose(tensor1, tensor2, rtol, atol, equal_nan=True)
sum1 = tensor_dev.sum()
sum2 = tensor_ref.sum()
mask = np.isclose(tensor_dev, tensor_ref, rtol, atol, equal_nan=True)
count = np.count_nonzero(np.equal(mask, False))
debug_level = glob_vars.debug_level
if tensor1.dtype == np.bool_:
if tensor_dev.dtype == np.bool_:
max_diff = 1
error_info = f"The count of elements that do not meet the accuracy requirement is {count}.\n" + \
f"Max of diff is {max_diff}.\n"
else:
assert tensor1.size == tensor2.size, "tensor1 element num does not equal tensor2's."
assert tensor_dev.size == tensor_ref.size, "tensor_dev element num does not equal tensor_ref's."
error_info = f"The count of elements that do not meet the accuracy requirement is {count}.\n" + \
f"The dtype of {var_name} is {tensor1.dtype}.\n" + \
f"The shape of {var_name} is {tensor1.shape}.\n" + \
f"The stride of {var_name} is {np.divide(tensor1.strides, tensor1.itemsize).astype(np.int32)}.\n"
nan_index = np.isnan(tensor1) | np.isnan(tensor2)
f"The dtype of {var_name} is {tensor_dev.dtype}.\n" + \
f"The shape of {var_name} is {tensor_dev.shape}.\n" + \
f"The stride of {var_name} is {np.divide(tensor_dev.strides, tensor_dev.itemsize).astype(np.int32)}.\n"
nan_index = np.isnan(tensor_dev) | np.isnan(tensor_ref)
nan_index[matched] = False # mismatched nan number index
if (len(np.argwhere(nan_index)) > 0):
# nan number exists
error_info += f"Exist mismatched nan number. E.g., the actual val is {tensor1[nan_index].ravel()[0]} and the expected is {tensor2[nan_index].ravel()[0]}.\n"
error_info += f"Exist mismatched nan number. E.g., the actual val is {tensor_dev[nan_index].ravel()[0]} and the expected is {tensor_ref[nan_index].ravel()[0]}.\n"
if (not np.array_equal(nan_index, ~matched)):
# not all mimatched numbers are nan,exists different numbers
diff = np.abs(tensor1 - tensor2)
diff = np.abs(tensor_dev - tensor_ref)
diff[matched] = 0
diff[nan_index] = 0
max_diff = diff.max()
max_diff_index = np.unravel_index(np.argmax(diff), diff.shape)
max_diff_elem = tensor1[max_diff_index]
max_diff_elem_ref = tensor2[max_diff_index]
max_diff_elem = tensor_dev[max_diff_index]
max_diff_elem_ref = tensor_ref[max_diff_index]
error_info += f"The max of diff is {max_diff}. Specifically, the actual val is {max_diff_elem} and the expected is {max_diff_elem_ref}.\n"
if debug_level > 0:
if np.isnan(sum1) or np.isnan(sum2):
error_info += f"Exists nan, {var_name} is {sum1} and {var_name}_ref is {sum2}.\n"
else:
error_info += f"Sum of {var_name} is {sum1}, Sum of {var_name}_ref is {sum2}, Max of diff is {max_diff}.\n"
if debug_level > 1:
error_info += f"{var_name} is {tensor1},\n{var_name}_ref is {tensor2},\nMask is {mask}\n"
error_info += f"{var_name} is {tensor_dev},\n{var_name}_ref is {tensor_ref},\nMask is {mask}\n"
raise OutputCheckFailedException(error_info)

0 comments on commit f686147

Please sign in to comment.