In general, we want to pick max(A_i, B_i) whenever possible in order to stick close to this hard upper bound.
However, when this results in the sum being a multiple of K, we might need to take some losses in order to make it work.
Going forward, assume A_i >= B_i.
Key observations:
1. If A_i % K == B_i % K, we never take B_i.
2. When we need to change a subarray's remainder modulo K and it's possible to do so, we only need to take one value from B.
Both can be proven using an exchange argument.
Once you have this foundation the problem becomes a lot more approachable.
There are lots of paths you can take from here, but personally I used a vaguely divide-and-conquer-like (DNQ) monotonic stack idea.
```cpp
void solve() {
int n, k;
cin >> n >> k;
vector<int> a(n), flex(n), inds;
for (int i = 0; i < n; i++) {
int x, y;
cin >> x >> y;
if (x < y) {
swap(x, y);
}
a[i] = x;
flex[i] = y % k != x % k ? x - y : INF;
if (flex[i] != INF) {
inds.pb(i);
}
}
mi ans = 0;
// Consider subarrays between flex elements
inds.insert(begin(inds), -1);
inds.pb(n);
for (int i = 1; i < sz(inds); i++) {
ll cnt = 1;
mi sum = 0;
map<int, ll> rem_cnt;
map<int, mi> rem_sum;
rem_cnt[0] = 1;
rem_sum[0] = 0;
ll psum = 0;
for (int j = inds[i - 1] + 1; j < inds[i]; j++) {
psum += a[j];
ans += mi(cnt - rem_cnt[psum % k]) * psum;
ans -= sum - rem_sum[psum % k];
cnt++;
sum += psum;
rem_cnt[psum % k] += 1;
rem_sum[psum % k] += psum;
}
}
inds.erase(begin(inds));
inds.pop_back();
// For everything else, we take the upper bound of max(a, b), excluding only
// the cheapest element that changes our remainder if that's invalid
stack<int> s;
vector<int> prev(n);
for (int i = 0; i < n; i++) {
while (sz(s) && flex[s.top()] > flex[i]) {
s.pop();
}
prev[i] = sz(s) ? s.top() : -1;
s.push(i);
}
while (sz(s)) {
s.pop();
}
vector<int> nxt(n);
for (int i = n - 1; i >= 0; i--) {
while (sz(s) && flex[s.top()] >= flex[i]) {
s.pop();
}
nxt[i] = sz(s) ? s.top() : n;
s.push(i);
}
for (int pos : inds) {
assert(pos >= 0 && pos < n && flex[pos] != INF);
ll cnt = 1;
mi sum = 0;
map<int, ll> rem_cnt;
rem_cnt[0] = 1;
ll suff = 0;
for (int i = pos - 1; i > prev[pos]; i--) {
suff += a[i];
cnt++;
sum += suff;
rem_cnt[suff % k]++;
}
ll pref = 0;
for (int i = pos; i < nxt[pos]; i++) {
pref += a[i];
ans += mi(cnt) * pref + sum;
ans -= mi(rem_cnt[(k - pref % k) % k]) * flex[pos];
}
}
cout << ans << '\n';
}
```