[题解]P2671 [NOIP2015 普及组] 求和
P2671 [NOIP2015 普及组] 求和
可以发现我们对相同颜色且编号奇偶性相同的元素归为一组,组内的元素两两都满足题目条件,且这样可以不重不漏覆盖所有答案。
设分完组之后,某一组内的元素编号分别是\(a_1,a_2,\dots,a_q\),数字分别是\(b_1,b_2,\dots,b_q\),则根据题意,该组的答案是:
\[\large{\sum\limits_{1\le j<i\le n}(a_i+a_j)\times(b_i+b_j)}
\]
对于每个\(i\),答案是:
\[\large{\sum\limits_{j=1}^{i-1}b_i\times c_i+b_i\times c_j+b_j\times c_i+b_j\times c_j}
\]
所以维护\(b\)、\(c\)、\(b\times c\)三个前缀和数组即可。这样对于每个\(i\)都可以\(O(1)\)求出。总时间复杂度\(O(n)\)。
点击查看代码
#include<bits/stdc++.h>
#define int long long
#define mod 10007
#define N 100010
#define M 100010
using namespace std;
int n,m,col[N],a[N],nn[2*M],ans;
vector<int> b[2*M],c[2*M],preb[2*M],prec[2*M],pre[2*M];
signed main(){
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) cin>>col[i];
for(int i=1;i<=(m<<1);i++) b[i].emplace_back(0),c[i].emplace_back(0);
for(int i=1,pos;i<=n;i++)
pos=col[i]+(1-i%2)*m,
b[pos].emplace_back(a[i]),
c[pos].emplace_back(i),
nn[pos]++;
for(int i=1;i<=(m<<1);i++){
pre[i].resize(nn[i]+1);
preb[i].resize(nn[i]+1);
prec[i].resize(nn[i]+1);
for(int j=1;j<=nn[i];j++){
preb[i][j]=(preb[i][j-1]+b[i][j])%mod;
prec[i][j]=(prec[i][j-1]+c[i][j])%mod;
pre[i][j]=(pre[i][j-1]+b[i][j]*c[i][j]%mod)%mod;
}
for(int j=2;j<=nn[i];j++){
ans=(ans+b[i][j]*c[i][j]%mod*(j-1)%mod)%mod;
ans=(ans+pre[i][j-1])%mod;
ans=(ans+b[i][j]*prec[i][j-1]%mod)%mod;
ans=(ans+c[i][j]*preb[i][j-1]%mod)%mod;
}
}
cout<<ans<<"\n";
return 0;
}
代码稍稍有些繁琐了,其实前缀和什么的,都可以分组的同时记录,由于只需要用到上一个位置的前缀和,所以第二位可以去掉。可以参考此题解的代码。
题解里也有用纯数学推导的,这里也写一下:
依旧设当前组的元素编号分别是\(a_1,a_2,\dots,a_q\),数字分别是\(b_1,b_2,\dots,b_q\)。这样该组的答案是:
\[\begin{aligned}
&\sum\limits_{1\le j<i\le n}(a_i+a_j)\times(b_i+b_j)\\
=\ &(a_1+a_2)\times(b_1+b_2)+(a_1+a_3)\times(b_1+b_3)+\dots+(a_2+a_3)\times(b_2+b_3)+\dots+(a_{q-1}+a_q)\times(b_{q-1}+b_q)\\
=\ &a_1\times(b_1+b_2+b_1+b_3+\dots+b_1+b_q)+a_2\times(b_2+b_1+b_2+b_3+\dots+b_2+b_q)+\dots+a_q\times(b_q+b_1+b_q+b_2+\dots+b_q+b_{q-1})\\
=\ &a_1\times((q-2)\times b_1+\sum\limits_{i=1}^{q}b_i)+a_2\times((q-2)\times b_2+\sum\limits_{i=1}^{q}b_i)+\dots+a_q\times((q-2)\times b_q+\sum\limits_{i=1}^{q}b_i)
\end{aligned}\]
于是我们维护每一组的\(\sum\limits_{i=1}^{q}b_i\),一次遍历统计答案即可。
点击查看代码
#include<bits/stdc++.h>
#define N 100010
#define M 100010
#define mod 10007
#define int long long
using namespace std;
int n,m,a[N],col[N],q[M][2],sum[M][2],ans;
signed main(){
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++){
cin>>col[i];
q[col[i]][i&1]++;
sum[col[i]][i&1]=(sum[col[i]][i&1]+a[i])%mod;
}
for(int i=1;i<=n;i++){
ans=(ans+i%mod*(a[i]*(q[col[i]][i&1]-2)%mod+mod+sum[col[i]][i&1])%mod)%mod;
}
cout<<ans<<"\n";
return 0;
}