# Edit code

def merge_sort_and_count_inversions(arr):
    """Sort *arr* with merge sort and count its inversions.

    An inversion is a pair of positions ``i < j`` whose items are out of
    order, i.e. ``not arr[i] <= arr[j]``.

    Args:
        arr: A sliceable sequence (e.g. a list or tuple) of mutually
            comparable items. It is only sliced, never modified.

    Returns:
        A tuple ``(sorted_items, inversions)``. ``sorted_items`` is always a
        new list holding the items of *arr* in non-decreasing order, and
        ``inversions`` is the number of inversions in *arr*.
    """
    if len(arr) <= 1:
        # Copy rather than returning ``arr`` itself, so the caller always
        # gets a fresh list that does not alias its input, whatever the
        # input's length or sequence type.
        return list(arr), 0

    mid = len(arr) // 2
    left, inversions_left = merge_sort_and_count_inversions(arr[:mid])
    right, inversions_right = merge_sort_and_count_inversions(arr[mid:])
    merged, inversions_split = merge_and_count_inversions(left, right)

    # Total = inversions inside each half + inversions spanning both halves.
    return merged, inversions_left + inversions_right + inversions_split

def merge_and_count_inversions(left, right):
    """Merge two sorted sequences and count the inversions between them.

    Both inputs are assumed to already be sorted in non-decreasing order.
    The count is correct only under that assumption. A counted pair is an
    item ``x`` from *left* and an item ``y`` from *right* with
    ``not x <= y``.

    Args:
        left: Sorted sequence whose items conceptually come first.
        right: Sorted sequence whose items conceptually come second.

    Returns:
        A tuple ``(merged, count)``. ``merged`` is a new list containing
        every item of *left* and *right* in order. On ties, the *left* item
        is placed first. ``count`` is the number of such cross pairs.
    """
    merged = []
    count = 0
    a = b = 0
    n_left = len(left)
    n_right = len(right)

    while a < n_left and b < n_right:
        x = left[a]
        y = right[b]
        if x <= y:
            merged.append(x)
            a += 1
            continue
        # y is placed ahead of every item still pending in left[a:], and
        # each of those items forms one inversion with y.
        merged.append(y)
        b += 1
        count += n_left - a

    # At most one of these tails is non-empty, and it is already sorted.
    merged.extend(left[a:])
    merged.extend(right[b:])
    return merged, count