[CCO 2024] Heavy Light Decomposition
前置知识
分块,DP
简要题意
定义“好的数组”为一个数组内交替出现“轻元素”和“重元素”,轻元素即其在这个数组内是唯一的,重元素即其在数组内出现多次。
有 $n$ 个正整数 $a_i$,求其有多少种划分方案能使划分后的子数组均为好的数组。
分析样例
我们首先要搞懂一个东西,就是划分后好的数组,是指这个子数组是好的,在这个子数组内的轻重元素与原数组并没有关系,每个子数组是互相独立的。
对于样例一,其划分方案如下:
- $[1], [2], [3], [2], [3]$
- $[1], [2, 3, 2], [3]$
- $[1], [2], [3, 2, 3]$
- $[1, 2, 3, 2], [3]$
对于样例二,其划分方案如下:
- $[1], [2], [1], [3], [1]$
- $[1, 2, 1], [3], [1]$
- $[1, 2, 1, 3], [1]$
- $[1], [2], [1, 3, 1]$
- $[1], [2, 1, 3, 1]$
- $[1, 2, 1, 3, 1]$
不明白的建议手推一下。
思路分析
考虑转移
我们定义 $dp[i]$ 为前 $i$ 个元素的合法划分的方案数。
那么,转移方程很显然:$dp[i] = \sum dp[j]$,其中 $j < i$ 且 子数组 $[j + 1, i]$ 是好数组。
实际上含义就是在 $j$ 处划分,新增一个子数组 $[j + 1, i]$,方案累加前 $j$ 个元素的方案数。
考虑好数组的约束
如果 $[j + 1, i]$ 为好的数组,那么需要满足:
1. 类型交替:即数组内的元素轻重交替。
2. 奇偶性约束:如果重元素第一次出现在奇数位,那么奇数位全是重元素,反之。
因此我们如果直接枚举所有的 $j$ 去验证 $[j + 1, i]$ 是否为好数组,时间复杂度为 $O(n ^ 2)$。
考虑优化
对于当前的位置 $i$,我们设其元素大小为 $v$,用 $odd[v]$ 和 $even[v]$ 来记录 $v$ 在奇数和偶数位最近的出现位置,这样的话可以确定 $j$ 的下界。
为了保证子数组 $[j + 1, i]$ 满足类型交替,需要避免 $v$ 元素在数组内出现奇偶性冲突,那么若 $j + 1 < \min(odd[v], even[v])$ 的话,则会使其冲突。
因此我们使 $minL$ 取所有元素 $min(odd[v], even[v]) + 1$ 的最大值,因此 $j > minL - 1$。
然后是最重要的分块,我们将原数组分块,每个块维护两个核心内容,$sum[k][b]$ 表示 $b$ 块满足在奇偶性 $k$ 下的合法的 $dp[j]$ 之和,$kpos[k][b]$ 是在 $k$ 的奇偶性下块 $b$ 是否满足。另外,为了维护分块时的单个元素, 我们维护 $pos[k][i]$ 是单个位置的 $i$ 是否满足奇偶性 $k$。
当我们查询 $[l, r]$ 内符合条件的 $dp[j]$ 之和时,对于整块,只需要判断 $kpos$ 是否有效,然后累加 $sum$ 即可,对于单个的块边缘的元素,则需要满足 $kpos$ 和 $pos$,有效则累加 $dp[j]$。
在区间更新时,只需要标记区间有效和无效,在完整的块上更新 $kpos$,零散的元素更新 $pos$ 和 $sum$。
时间复杂度
分块的单次查询和更新的时间复杂度为 $O(\sqrt{n})$,时间复杂度为 $O(n \sqrt n)$。
简单卡常即可,最慢的点才两秒出头,对于四秒的时间限制完全够用。
代码
#include<iostream>
#include<cstdio>
#include<cmath>
#include<vector>
#include<unordered_map>
#include<cstring>
using namespace std;
const int MOD = 1e6 + 3;
const int MAXN = 5 * 1e5 + 5;
int kpos[2][MAXN], sum[2][MAXN], pos[2][MAXN];
int L[MAXN], R[MAXN], id[MAXN];
int dp[MAXN], even[MAXN], odd[MAXN];
int a[MAXN], pre[MAXN];
pair<int, int> lst[MAXN];
int n, tot = 0, B;
inline int read(){
int x = 0, f = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-') f = -1;
ch = getchar();
}
while(ch>='0' && ch<='9')
x = x * 10 + ch - '0', ch = getchar();
return x * f;
}
inline void update(int k, int l, int r, int x){
if(l > r) return;
int kl = id[l], kr = id[r];
if(kl != kr){
for(int i = kl + 1;i < kr;++ i){
kpos[k][i] += x;
}
if(l == L[kl]){
kpos[k][kl] += x;
}
else{
for(int i = l;i <= R[kl];++ i){
if(pos[k][i]){
sum[k][kl] = (sum[k][kl] + dp[i - 1]) % MOD;
}
pos[k][i] += x;
if(pos[k][i]){
sum[k][kl] = (sum[k][kl] - dp[i - 1] + MOD) % MOD;
}
}
}
if(r == R[kr]){
kpos[k][kr] += x;
}
else{
for(int i = L[kr];i <= r;++ i){
if(pos[k][i]){
sum[k][kr] = (sum[k][kr] + dp[i - 1]) % MOD;
}
pos[k][i] += x;
if(pos[k][i]){
sum[k][kr] = (sum[k][kr] - dp[i - 1] + MOD) % MOD;
}
}
}
}
else{
if(l == L[kl] && r == R[kl]){
kpos[k][kl] += x;
}
else {
for(int i = l;i <= r;++ i){
if(pos[k][i]){
sum[k][kl] = (sum[k][kl] + dp[i - 1]) % MOD;
}
pos[k][i] += x;
if(pos[k][i]){
sum[k][kl] = (sum[k][kl] - dp[i - 1] + MOD) % MOD;
}
}
}
}
}
inline int query(int k, int l, int r){
if(l > r) return 0;
int res = 0;
int kl = id[l], kr = id[r];
if(kl != kr){
for(int i = kl + 1;i < kr;++ i){
if(!kpos[k][i]){
res = (res + sum[k][i]) % MOD;
}
}
if(l == L[kl]){
if(!kpos[k][kl]){
res = (res + sum[k][kl]) % MOD;
}
}
else{
for(int i = l; i <= R[kl]; i++){
if(!pos[k][i] && !kpos[k][kl]){
res = (res + dp[i - 1]) % MOD;
}
}
}
if(r == R[kr]){
if(!kpos[k][kr]){
res = (res + sum[k][kr]) % MOD;
}
}
else{
for(int i = L[kr];i <= r;++ i){
if (!pos[k][i] && !kpos[k][kr]){
res = (res + dp[i - 1]) % MOD;
}
}
}
}
else{
if(l == L[kl] && r == R[kl]){
if(!kpos[k][kl]){
res = (res + sum[k][kl]) % MOD;
}
}
else{
for(int i = l;i <= r;++ i){
if (!pos[k][i] && !kpos[k][kl]){
res = (res + dp[i - 1]) % MOD;
}
}
}
}
return res;
}
int main(){
freopen("digit.in", "r", stdin);
freopen("digit.out", "w", stdout);
n = read();
B = 300;
for(int i = 1; i <= n; i += B){
L[++ tot] = i;
R[tot] = min(n, i + B - 1);
for(int j = i;j <= R[tot];++ j){
id[j] = tot;
}
}
// for(int i = 1;i <= n;++ i){
// cout << i << ' ' << id[i] << '\n;'
// }
dp[0] = 1;
for(int i = 1;i <= n;++ i) a[i] = read();
int leven = 0, lodd = 0, minL = 1;
memset(lst, -1, sizeof(lst));
for(int i = 1;i <= n;++ i){
sum[0][id[i]] = (sum[0][id[i]] + dp[i - 1]) % MOD;
sum[1][id[i]] = (sum[1][id[i]] + dp[i - 1]) % MOD;
dp[i] = dp[i - 1];
int v = a[i];
if(lst[v].second != -1){
update(lst[v].second & 1, lst[v].first + 1, lst[v].second, -1);
}
if(i & 1){
lodd = max(lodd, odd[v]);
odd[v] = i;
}
else{
leven = max(leven, even[v]);
even[v] = i;
}
lst[v] = {pre[v], i};
update(lst[v].second & 1, lst[v].first + 1, lst[v].second, 1);
pre[v] = i;
minL = max(minL, min(odd[v], even[v]) + 1);
if(leven < lodd){
dp[i] = (dp[i] + query(1, max(minL, leven + 1), i - 1)) % MOD;
}
else if (lodd < leven){
dp[i] = (dp[i] + query(0, max(minL, lodd + 1), i - 1)) % MOD;
}
}
printf("%d", dp[n] % MOD);
return 0;
}