Ta sẽ sử dụng kĩ thuật chia để trị để giải quyết bài toán này trong $\mathcal{O}(n \times log(n))$
Xét trên đoạn $[l, r]$, ta chia đoạn này thành 2 nửa, $[l, mid]$ và $[mid + 1, r]$, với $mid$ là điểm nằm giữa $l$ và $r$.
Cố định $i$ bằng cách duyệt $i$ trên đoạn $[l, mid]$, khi này, ta có thể chia thành 4 trường hợp:
- $x, y \in [i, mid]$
- $x, y \in [mid + 1, r]$
- $x \in [i, mid]$, còn $y \in [mid + 1, r]$
- $y \in [i, mid]$, còn $x \in [mid + 1, r]$
Với:
$x = min(a_i, a_{i + 1}, a_{i + 2}, ..., a_{mid})$
$y = max(a_i, a_{i + 1}, a_{i + 2}, ..., a_{mid})$
Gọi $j$ là vị trí đầu tiên thuộc $[mid + 1, r]$ mà $a_j < x$. Tương tự, gọi $k$ là vị trí đầu tiên thuộc $[mid + 1, r]$ mà $a_k > y$. Ta có thể tìm vị trí này bằng cách dùng hai con trỏ
Như vậy, với mọi vị trí thuộc khoảng $[i, min(j, k) - 1]$ đều sẽ rơi vào trường hợp 1, đáp án ở đây là $(min(j, k) - mid - 1) \times x \times y$
Tương tự, với mọi vị trí thuộc khoảng $[max(j, k) + 1, r]$ đều sẽ rơi vào trường hợp 2, đáp án là tổng $min(a_{mid}, a_{mid + 1}, ..., a_{p}) \times max(a_{mid}, a_{mid + 1}, ..., a_{p})$ với $p \in [max(j, k) + 1, r]$ ta có thể tính nhanh tổng ở đây bằng mảng tiền tố.
Còn lại trường hợp 3, và 4. Khi này:
- Nếu $j < k$, ta sẽ lấy $x$ và tổng của $min(a_{mid + 1}, a_{mid + 2}, ..., a_{p})$, với $p \in (j, k]$, ta có thể tính nhanh tổng này bằng mảng tiền tố
- Ngược lại, nếu $j > k$, ta cũng có thể làm tương tự, nhưng với $y$ và tổng của max trên đoạn.
**Code tham khảo:**
```cpp
#include "bits/stdc++.h"
using namespace std;
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
#define int int64_t
int32_t main() {
cin.tie(0)->sync_with_stdio(0);
int n;
cin >> n;
vector<int> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}
int res = 0;
const int inf = 1e9;
vector<int> pre(n), smn(n), smx(n);
const int mod = (int)1e9 + 7;
auto add = [&](int x, int y) {
x += y;
if (x >= mod) {
x -= mod;
}
if (x < 0) {
x += mod;
}
return x;
};
auto mult = [&](int x, int y) -> int {
return 1ll * x * y % mod;
};
function<void(int, int)> dnc = [&](int l, int r) {
if (l == r) {
res += mult(a[l], a[l]);
return;
}
int mid = (l + r) >> 1;
dnc(l, mid);
dnc(mid + 1, r);
int mn = inf, mx = -inf;
smn[mid] = smx[mid] = pre[mid] = 0;
for (int i = mid + 1; i <= r; i++) {
mn = min(mn, a[i]);
mx = max(mx, a[i]);
smn[i] = add(smn[i - 1], mn);
smx[i] = add(smx[i - 1], mx);
pre[i] = add(pre[i - 1], mult(mn, mx));
}
mn = inf, mx = -inf;
int i = mid, j = mid + 1, k = mid + 1;
while (i >= l) {
mn = min(mn, a[i]);
mx = max(mx, a[i]);
while (j <= r && a[j] >= mn) {
j++;
}
while (k <= r && a[k] <= mx) {
k++;
}
if (j < k) {
res = add(res, mult(j - mid - 1, mult(mn, mx)));
res = add(res, mult(mx, add(smn[k - 1], -smn[j - 1])));
res = add(res, add(pre[r], -pre[k - 1]));
} else {
res = add(res, mult(k - mid - 1, mult(mn, mx)));
res = add(res, mult(mn, add(smx[j - 1], -smx[k - 1])));
res = add(res, add(pre[r], -pre[j - 1]));
}
i--;
}
};
dnc(0, n - 1);
cout << res;
}
```