Function sp_arithmetic::normalize

pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
where
    T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,

Normalize input so that the sum of all elements reaches targeted_sum.

This implementation currently strikes a balance between performance and accuracy.

  1. We store the original indices and sort the input only once. This avoids re-sorting on every round, at the cost of a little extra memory.
  2. The granularity of increments/decrements is determined by the number of elements in input and the difference between their sum and targeted_sum, namely the absolute difference diff = |sum(input) - targeted_sum|. This value is then split into per_round = diff / input.len() and leftover = diff % input.len(). First, per_round is applied to all elements of input; then the leftover is applied by adding/subtracting 1 at a time until it is depleted.
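A minimal sketch of the steps above, restricted to the case sum(input) <= targeted_sum and using u64 in place of the generic T. The function name normalize_up, the Option return type, and the choice to hand leftover units to the smallest elements first are assumptions of this sketch, not the sp_arithmetic code.

```rust
/// Hypothetical sketch (not the sp_arithmetic implementation): only handles
/// sum(input) <= target, with u64 standing in for the generic T.
fn normalize_up(input: &[u64], target: u64) -> Option<Vec<u64>> {
    // Sum with overflow checking; None if the sum does not fit in u64.
    let sum: u64 = input.iter().try_fold(0u64, |acc, &x| acc.checked_add(x))?;
    // Empty input (nothing to distribute) or the other direction: not handled here.
    if input.is_empty() || sum > target {
        return None;
    }
    let n = input.len() as u64;

    // Split the difference into a uniform part and a remainder (< n).
    let diff = target - sum;
    let per_round = diff / n;
    let mut leftover = diff % n;

    // Sort once, remembering each element's original index.
    let mut with_idx: Vec<(usize, u64)> = input.iter().copied().enumerate().collect();
    with_idx.sort_by_key(|&(_, v)| v);

    // Apply per_round to every element; x + per_round <= sum + diff = target.
    for (_, v) in with_idx.iter_mut() {
        *v += per_round;
    }

    // Hand out the leftover one unit at a time, smallest elements first.
    // Since leftover < n, each element receives at most one extra unit.
    let mut i = 0;
    while leftover > 0 {
        with_idx[i].1 += 1;
        leftover -= 1;
        i += 1;
    }

    // Restore the original order and drop the indices.
    with_idx.sort_by_key(|&(idx, _)| idx);
    Some(with_idx.into_iter().map(|(_, v)| v).collect())
}
```

For example, normalize_up(&[3, 1, 2], 10) computes diff = 4, per_round = 1 and leftover = 1; the extra unit goes to the element that was originally 1, giving [4, 3, 3].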

When the sum is less than the target, the approach above is always safe. In this case, each individual element is also less than targeted_sum, and per_round is at most the difference between the two. Thus, adding per_round to each item cannot overflow the numeric bound of T. In fact, none of them can go beyond targeted_sum*.
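The bound can be written out as follows, using S = sum(input), t = targeted_sum, n = input.len() and x_i for an element of input as shorthand introduced here:

```latex
x_i \le S, \qquad
\mathrm{per\_round} = \left\lfloor \frac{t - S}{n} \right\rfloor \le t - S
\;\Longrightarrow\;
x_i + \mathrm{per\_round} \le S + (t - S) = t
```

The leftover units only move the total toward t, and every element is bounded by the total, so no element ever exceeds t during the one-by-one phase either.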

If the sum is more than the target, there is a small twist: subtracting per_round from an element might go below zero. In this case, we saturate at zero and add the part that could not be subtracted to leftover. This ensures that the result always stays accurate, but it can make execution increasingly slow, since the leftover is applied one unit at a time.
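A minimal sketch of the sum(input) > targeted_sum case, again with u64 in place of T. The name normalize_down is hypothetical, and taking each leftover unit from the current largest element via a linear scan is a simplification chosen for this sketch rather than the library's strategy.

```rust
/// Hypothetical sketch (not the sp_arithmetic implementation): only handles
/// sum(input) >= target, with u64 standing in for the generic T.
fn normalize_down(input: &[u64], target: u64) -> Option<Vec<u64>> {
    // Sum with overflow checking; None if the sum does not fit in u64.
    let sum: u64 = input.iter().try_fold(0u64, |acc, &x| acc.checked_add(x))?;
    if input.is_empty() || sum < target {
        return None;
    }
    let n = input.len() as u64;
    let diff = sum - target;
    let per_round = diff / n;
    let mut leftover = diff % n;

    let mut out = input.to_vec();
    for v in out.iter_mut() {
        // Saturate at zero; the part that could not be subtracted is carried
        // into leftover so the final sum is still exact.
        let shortfall = per_round.saturating_sub(*v);
        *v = v.saturating_sub(per_round);
        leftover += shortfall;
    }

    // Invariant: leftover == current sum - target, so while leftover > 0 the
    // largest element is non-zero and can be decremented.
    while leftover > 0 {
        let largest = out.iter_mut().max().expect("input is non-empty");
        *largest -= 1;
        leftover -= 1;
    }
    Some(out)
}
```

With input [10, 1, 1] and target 3: diff = 9 and per_round = 3, so the two 1s saturate at 0 and each carries a shortfall of 2 into leftover; the remaining 4 units then come off the 7 one at a time, giving [3, 0, 0]. The larger the carried shortfall, the more single-unit iterations are needed.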

All in all, the complicated case above is rare in most use cases within this repo, hence we opt for this approach due to its simplicity.

This function returns an error if the length of input cannot fit in T, or if sum(input) cannot fit in T.
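The two failure conditions can be illustrated with T = u8. The helper check_fits and its error strings are hypothetical, chosen for this example; they only mirror the documented conditions and are not the library's messages.

```rust
use std::convert::TryFrom;

/// Hypothetical helper (not part of sp_arithmetic): performs the two
/// documented checks for T = u8 and returns (length, sum) on success.
fn check_fits(input: &[u8]) -> Result<(u8, u8), &'static str> {
    // The element count must be representable in T.
    let len = u8::try_from(input.len()).map_err(|_| "length of input cannot fit in T")?;
    // The sum must be representable in T; checked_add fails on overflow.
    let sum = input
        .iter()
        .try_fold(0u8, |acc, &x| acc.checked_add(x))
        .ok_or("sum of input cannot fit in T")?;
    Ok((len, sum))
}
```

For instance, 256 elements fail the length check even if they are all zero, while [200, 100] fails the sum check because 300 exceeds u8::MAX.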

* This proof is used in the implementation as well.