Solutions of Coins flip - MarisaOJ: Marisa Online Judge

Posted by mark

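In general, we want to pick max(A_i, B_i) whenever possible, since the sum of these maxima is a hard upper bound on the total. However, when that sum is a multiple of K, we may need to take some losses to make it work.

Going forward, assume A_i >= B_i (swap the pair otherwise). Key observations:

1. If A_i % K == B_i % K, we never take B_i: it is no larger than A_i and does not change the remainder.
2. When we need to change a subarray sum's remainder modulo K and it is possible to do so, we only need to take one value from B, namely the one with the smallest A_i - B_i among elements whose A_i and B_i differ modulo K.

Both can be proven with an exchange argument. Once you have this foundation, the problem becomes much more approachable. There are many paths you can take from here; personally, I used a vaguely divide-and-conquer-like monotonic stack idea. Subarrays containing no "flexible" element (one with A_i % K != B_i % K) are handled with prefix sums grouped by remainder, and every other subarray is charged to its leftmost cheapest flexible element.

```cpp
// Assumes the author's usual template (not shown in the original):
//   ll  = long long, pb = push_back, sz(x) = (int)(x).size(),
//   INF = a large sentinel (larger than any A_i - B_i),
//   mi  = a modular-integer type; the modulus comes from the problem statement.
void solve() {
    int n, k;
    cin >> n >> k;
    vector<int> a(n), flex(n), inds;
    for (int i = 0; i < n; i++) {
        int x, y;
        cin >> x >> y;
        if (x < y) { swap(x, y); }
        a[i] = x;  // take the larger value by default
        // Cost of switching to the smaller value, or INF if switching
        // cannot change the remainder mod k (observation 1).
        flex[i] = y % k != x % k ? x - y : INF;
        if (flex[i] != INF) { inds.pb(i); }
    }

    mi ans = 0;

    // Consider subarrays between flex elements: every choice has the same
    // remainder, so a subarray contributes its sum of maxima only if that
    // sum is not a multiple of k.
    inds.insert(begin(inds), -1);
    inds.pb(n);
    for (int i = 1; i < sz(inds); i++) {
        ll cnt = 1;            // number of prefix sums seen so far
        mi sum = 0;            // total of those prefix sums
        map<int, ll> rem_cnt;  // prefix-sum count by remainder
        map<int, mi> rem_sum;  // prefix-sum total by remainder
        rem_cnt[0] = 1;
        rem_sum[0] = 0;
        ll psum = 0;
        for (int j = inds[i - 1] + 1; j < inds[i]; j++) {
            psum += a[j];
            // Add (psum - p) for every earlier prefix p with a different remainder.
            ans += mi(cnt - rem_cnt[psum % k]) * psum;
            ans -= sum - rem_sum[psum % k];
            cnt++;
            sum += psum;
            rem_cnt[psum % k] += 1;
            rem_sum[psum % k] += psum;
        }
    }
    inds.erase(begin(inds));
    inds.pop_back();

    // For everything else, we take the upper bound of max(a, b), excluding only
    // the cheapest element that changes our remainder if that's invalid.
    // prev[i]: nearest j < i with flex[j] <= flex[i]; nxt[i]: nearest j > i with
    // flex[j] < flex[i]. So i is the leftmost cheapest flex element of every
    // subarray [l, r] with prev[i] < l <= i <= r < nxt[i].
    stack<int> s;
    vector<int> prev(n);
    for (int i = 0; i < n; i++) {
        while (sz(s) && flex[s.top()] > flex[i]) { s.pop(); }
        prev[i] = sz(s) ? s.top() : -1;
        s.push(i);
    }
    while (sz(s)) { s.pop(); }
    vector<int> nxt(n);
    for (int i = n - 1; i >= 0; i--) {
        while (sz(s) && flex[s.top()] >= flex[i]) { s.pop(); }
        nxt[i] = sz(s) ? s.top() : n;
        s.push(i);
    }
    for (int pos : inds) {
        assert(pos >= 0 && pos < n && flex[pos] != INF);
        // Left parts: suffix sums a[l..pos-1] for prev[pos] < l <= pos.
        ll cnt = 1;
        mi sum = 0;
        map<int, ll> rem_cnt;
        rem_cnt[0] = 1;
        ll suff = 0;
        for (int i = pos - 1; i > prev[pos]; i--) {
            suff += a[i];
            cnt++;
            sum += suff;
            rem_cnt[suff % k]++;
        }
        // Right parts: prefix sums a[pos..r] for pos <= r < nxt[pos].
        ll pref = 0;
        for (int i = pos; i < nxt[pos]; i++) {
            pref += a[i];
            // Total of all (suff + pref) combinations...
            ans += mi(cnt) * pref + sum;
            // ...minus flex[pos] for each combination whose total is a multiple of k.
            ans -= mi(rem_cnt[(k - pref % k) % k]) * flex[pos];
        }
    }
    cout << ans << '\n';
}
```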
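To sanity-check the two key observations, here is a small brute-force sketch. The helpers `brute_best`, `greedy_best`, `total_brute` and `greedy_matches_brute` are hypothetical and not part of the original solution. It assumes, as the code above suggests, that each subarray contributes its largest achievable total that is not a multiple of K (or 0 if no choice works), and it uses exact `long long` arithmetic with no modulus, so it is only meant for tiny inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <random>
#include <vector>
using namespace std;

// Remainder of x modulo k in [0, k), safe for negative x.
long long mod_k(long long x, long long k) { return ((x % k) + k) % k; }

// Brute force for ONE subarray: try all 2^len ways of taking A_i or B_i and keep
// the largest total that is not a multiple of k. Returns 0 if no choice works
// (assumed to be how such a subarray is scored).
long long brute_best(const vector<long long>& a, const vector<long long>& b, long long k) {
    int len = (int)a.size();
    long long best = 0;
    bool found = false;
    for (int mask = 0; mask < (1 << len); mask++) {
        long long s = 0;
        for (int i = 0; i < len; i++) s += ((mask >> i) & 1) ? b[i] : a[i];
        if (mod_k(s, k) != 0 && (!found || s > best)) {
            best = s;
            found = true;
        }
    }
    return found ? best : 0;
}

// Greedy from the two observations: take max(A_i, B_i) everywhere; if that total
// is a multiple of k, give up the single cheapest element whose two values differ
// modulo k. If no such element exists, every choice has remainder 0.
long long greedy_best(const vector<long long>& a, const vector<long long>& b, long long k) {
    long long s = 0, cheapest = LLONG_MAX;
    for (size_t i = 0; i < a.size(); i++) {
        long long hi = max(a[i], b[i]), lo = min(a[i], b[i]);
        s += hi;
        if (mod_k(hi, k) != mod_k(lo, k)) cheapest = min(cheapest, hi - lo);
    }
    if (mod_k(s, k) != 0) return s;
    return cheapest == LLONG_MAX ? 0 : s - cheapest;
}

// Exact sum of brute_best over all subarrays (tiny inputs only, no modulus).
long long total_brute(const vector<long long>& a, const vector<long long>& b, long long k) {
    long long total = 0;
    for (size_t l = 0; l < a.size(); l++) {
        for (size_t r = l; r < a.size(); r++) {
            vector<long long> sa(a.begin() + l, a.begin() + r + 1);
            vector<long long> sb(b.begin() + l, b.begin() + r + 1);
            total += brute_best(sa, sb, k);
        }
    }
    return total;
}

// Random cross-check of greedy_best against brute_best on small subarrays.
bool greedy_matches_brute(int trials, unsigned seed) {
    mt19937 rng(seed);
    for (int t = 0; t < trials; t++) {
        int len = (int)(rng() % 6) + 1;
        long long k = (long long)(rng() % 5) + 2;
        vector<long long> a(len), b(len);
        for (int i = 0; i < len; i++) {
            a[i] = rng() % 10;
            b[i] = rng() % 10;
        }
        if (brute_best(a, b, k) != greedy_best(a, b, k)) return false;
    }
    return true;
}
```

For example, `greedy_best({3}, {1}, 3)` returns 1: taking 3 gives a multiple of 3, so the only valid choice is 1. On tiny random inputs, `total_brute` can also be compared against the output of `solve()`, as long as the exact total stays below the problem's modulus.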