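// 点击查看代码 ("Click to view code") — label copied from a blog page; kept only as a comment so the file compiles.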
#include <bits/stdc++.h>
using namespace std;
// Capacity shared by every position-indexed and node-indexed table.
// NOTE(review): presumably the problem's upper bound on n plus slack — confirm.
constexpr int kMaxN = 3000005;

char c[kMaxN];              // c[i]: i-th bracket character, 1-indexed
int w[kMaxN];               // w[i]: weight read for position i, 1-indexed
int fa[kMaxN];              // fa[v]: parent of bracket node v (0 is the virtual root)
int l[kMaxN];               // l[v]: position of node v's '('
int r[kMaxN];               // r[v]: position of v's matching ')', stays 0 if unmatched
long long f[kMaxN][2];      // f[v][0], f[v][1]: DP values of node v
long long val[kMaxN];       // val[v] = w[r[v]] - w[l[v]] for matched nodes
std::vector<int> a[kMaxN];  // a[v]: children of node v, left to right
// Merges the (already finalised) children of node v into f[v][0] / f[v][1].
//   f[v][0]: best total inside v's subtree; when v is matched, taking v whole
//            (val[v]) is also a candidate.
//   f[v][1]: while scanning children left to right, best total in which the
//            last chosen segment ends exactly at the current matched child, so
//            an immediately adjacent sibling may extend that segment.
static void combine_children(int v)
{
    long long &best = f[v][0];
    long long &open = f[v][1];
    int cur = -1;  // r[] of the previous matched sibling; -1 => next child cannot be adjacent
    for (int child : a[v])
    {
        if (!r[child])
        {
            // Unmatched '(' child: only its internal answer counts.
            best += f[child][0];
            cur = -1;
            continue;
        }
        if (cur + 1 == l[child])
        {
            // child starts right after the previous sibling's ')': extending the
            // segment that ended at cur adds w[r[child]] - w[cur].
            const long long extended = open + val[child] + w[l[child]] - w[cur];
            open = max(extended, best + val[child]);        // uses best before its update
            best = max(best + f[child][0], extended);
        }
        else
        {
            // Not adjacent: a segment ending here must start at this child.
            open = best + val[child];
            best += f[child][0];
        }
        cur = r[child];
    }
    if (r[v])
        best = max(best, val[v]);
}

// Computes f[][] for every node in the subtree rooted at n1, without recursion.
// main() numbers nodes in preorder (tot++ on each '('), so:
//   * every child has a larger id than its parent, and
//   * the subtree of n1 is the contiguous id range [n1, last], where `last` is
//     reached by repeatedly following the final child.
// Sweeping ids from `last` down to n1 therefore finalises all children of a
// node before the node itself — the same dependency order as a post-order DFS —
// without a call stack as deep as the bracket nesting (which can be ~n).
void dp(int n1)
{
    int last = n1;
    while (!a[last].empty())
        last = a[last].back();
    for (int v = last; v >= n1; --v)
        combine_children(v);
}
// Reads n, the bracket string c[1..n] and weights w[1..n], builds the forest
// of matched bracket pairs, runs the tree DP and prints f[0][0].
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> c[i];
    for (int i = 1; i <= n; i++)
        cin >> w[i];

    // Build the tree: node 0 is a virtual root, every '(' opens a new node
    // (ids assigned in preorder) under the currently open node p.
    int p = 0, tot = 0;
    for (int i = 1; i <= n; i++)
    {
        if (c[i] == '(')
        {
            ++tot;
            fa[tot] = p;
            a[p].push_back(tot);
            p = tot;
            l[p] = i;
        }
        else if (p)  // a closing character with no open '(' (p == 0) is ignored
        {
            r[p] = i;
            // Widen before subtracting: w is int, and w[r] - w[l] computed in
            // int can overflow even though the result is stored as long long.
            val[p] = static_cast<long long>(w[r[p]]) - w[l[p]];
            p = fa[p];
        }
    }
    dp(0);
    cout << f[0][0] << '\n';
    return 0;
}