1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use primitive_types::U256;
/// Selects which component of an approximated ratio receives a bias (see [`Rounding::to_bias`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Rounding {
    /// No bias is applied to either component.
    Nearest,
    /// The bias goes on the denominator; growing the denominator can only shrink a
    /// non-negative ratio.
    Down,
    /// The bias goes on the numerator; growing the numerator can only grow a
    /// non-negative ratio.
    Up,
}

impl Rounding {
    /// Returns the `(numerator_bias, denominator_bias)` pair for this rounding mode.
    ///
    /// `Up` yields `(magnitude, 0)`, `Down` yields `(0, magnitude)`, and `Nearest`
    /// yields `(0, 0)` regardless of `magnitude`.
    pub fn to_bias(self, magnitude: u128) -> (u128, u128) {
        match self {
            Self::Up => (magnitude, 0),
            Self::Down => (0, magnitude),
            Self::Nearest => (0, 0),
        }
    }
}
/// Approximates the ratio `n / d` of two `U256` values with a pair of `u128` values.
///
/// If both `n` and `d` already fit in 128 bits they are returned as-is. Otherwise both are
/// shifted right by the same number of bits so the larger one fits in a `u128`, and a bias
/// of 1 selected by `rounding` is added:
///
/// - [`Rounding::Up`] adds 1 to the shifted numerator, so the result is at least `n / d`.
/// - [`Rounding::Down`] adds 1 to the shifted denominator, so the result is at most `n / d`.
/// - [`Rounding::Nearest`] adds nothing; both components are simply truncated.
///
/// In the shifted case, a non-zero numerator is clamped to at least 1 (a positive ratio never
/// becomes zero), and the denominator is clamped to at least 1. When either clamp fires, the
/// direction guarantee above may not hold. This happens when the exact ratio is too small or
/// too large to be represented after shifting.
pub fn round_to_rational((n, d): (U256, U256), rounding: Rounding) -> (u128, u128) {
    let max_bits = n.bits().max(d.bits());
    if max_bits <= 128 {
        // Both operands fit in 128 bits, so `low_u128` is lossless here.
        return (n.low_u128(), d.low_u128());
    }

    let (bias_n, bias_d) = rounding.to_bias(1);

    // Smallest shift that makes the larger operand fit in 128 bits.
    let min_shift = max_bits - 128;
    // If adding the bias to a shifted component would exceed `u128::MAX`, the bias could not
    // be applied and the rounding direction would be lost. Shifting one more bit keeps both
    // components below 2^127, so the bias always fits.
    let bias_overflows =
        |v: U256, bias: u128| (v >> min_shift).low_u128().checked_add(bias).is_none();
    let shift = if bias_overflows(n, bias_n) || bias_overflows(d, bias_d) {
        min_shift + 1
    } else {
        min_shift
    };

    let shifted_n = (n >> shift).low_u128();
    let shifted_d = (d >> shift).low_u128();
    // Keep a non-zero numerator from collapsing to zero; never return a zero denominator.
    let min_n = if n.is_zero() { 0 } else { 1 };
    // The additions cannot overflow: `shift` was chosen above so that each bias fits.
    (
        (shifted_n + bias_n).max(min_n),
        (shifted_d + bias_d).max(1),
    )
}
/// Checks `round_to_rational` with `Rounding::Nearest` on small values and on values at the
/// extremes of the `U256` range.
#[test]
fn round_to_rational_should_work() {
    let one = U256::from(1);
    // Each entry is (input ratio, expected u128 approximation).
    let cases: [((U256, U256), (u128, u128)); 4] = [
        ((one, one), (1, 1)),
        ((U256::MAX, U256::MAX), (u128::MAX, u128::MAX)),
        ((U256::MAX, one), (u128::MAX, 1)),
        ((one, U256::MAX), (1, u128::MAX)),
    ];
    for (input, expected) in cases {
        assert_eq!(round_to_rational(input, Rounding::Nearest), expected);
    }
}