Problem link: CodeForces-1324F
Problem Description
You are given a tree consisting of $n$ vertices. A tree is a connected undirected graph with $n-1$ edges. Each vertex $v$ of this tree has a color assigned to it ($a_v=1$ if the vertex $v$ is white and $a_v=0$ if the vertex $v$ is black).
You have to solve the following problem for each vertex $v$: what is the maximum difference between the number of white and the number of black vertices you can obtain if you choose some subtree of the given tree that contains the vertex $v$? The subtree of the tree is the connected subgraph of the given tree. More formally, if you choose the subtree that contains $cnt_w$ white vertices and $cnt_b$ black vertices, you have to maximize $cnt_w-cnt_b$.
Input
The first line of the input contains one integer $n$ $(2 \le n \le 2 \cdot 10^5)$, the number of vertices in the tree.
The second line of the input contains $n$ integers $a_1, a_2, \ldots, a_n$ $(0 \le a_i \le 1)$, where $a_i$ is the color of the $i$-th vertex.
Each of the next $n-1$ lines describes an edge of the tree. Edge $i$ is denoted by two integers $u_i$ and $v_i$, the labels of vertices it connects $(1 \le u_i, v_i \le n,\ u_i \ne v_i)$.
It is guaranteed that the given edges form a tree.
Output
Print $n$ integers $res_1, res_2, \ldots, res_n$, where $res_i$ is the maximum possible difference between the number of white and black vertices in some subtree that contains the vertex $i$.
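Approach
Give each white vertex weight $+1$ and each black vertex weight $-1$, so that $cnt_w-cnt_b$ is simply the total weight of the chosen subtree. Root the tree at vertex 1 and let $dp[i]$ be the best total weight of a connected subgraph that contains $i$ and stays inside the subtree of $i$; a child's part is worth attaching only when its value is positive. $dp[1]$ is then the answer for vertex 1, but the problem asks for an answer at every vertex, so a second pass re-roots the tree (rerooting DP) and pushes the contribution of the parent's side down to each child.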
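Written out as recurrences, with weight $w_v = 2a_v - 1$ (so $+1$ for white, $-1$ for black) and $p$ the parent of $v$ when the tree is rooted at vertex 1, the two passes in the AC code compute:

$$dp[u] = w_u + \sum_{c \in \mathrm{children}(u)} \max\bigl(0,\ dp[c]\bigr)$$

$$ans[1] = dp[1], \qquad ans[v] = dp[v] + \max\Bigl(0,\ ans[p] - \max\bigl(0,\ dp[v]\bigr)\Bigr) \quad (v \ne 1)$$

Here $ans[p] - \max(0, dp[v])$ is the best value of a connected subgraph that contains $p$ but does not enter the subtree of $v$, so $v$ attaches the parent's side only when that value is positive.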
Sample Input
9
0 1 1 1 0 0 0 0 1
1 2
1 3
3 4
3 5
2 6
4 7
6 8
5 9
Sample Output
2 2 2 2 2 1 1 0 2
AC Code
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#define IOS ios_base::sync_with_stdio(0); cin.tie(0);cout.tie(0);
using namespace std;
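const int N=1e6+10;   // array bound; n <= 2e5 would already be enough
int dp[N],ans[N];     // dp[u]: weight of u at first, then the DP value described above
vector<int>e[N];      // adjacency lists
// Pass 1 (tree rooted at vertex 1): dp[u] = w_u + sum of max(0, dp[child]),
// the best connected subgraph inside u's subtree that contains u.
void dfs1(int u,int fa) {
    for(auto v:e[u]) {
        if(v==fa)continue;
        dfs1(v,u);
        dp[u]+=max(0,dp[v]);   // attach the child's part only if it helps
    }
}
// Pass 2 (rerooting): on entry, dp[u] is already the answer for u over the whole tree.
void dfs2(int u,int fa) {
    ans[u]=dp[u];
    for(auto v:e[u]) {
        if(v==fa)continue;
        dp[u]-=max(0,dp[v]);   // u's best value without v's side
        dp[v]+=max(0,dp[u]);   // treat u as a child of v
        dfs2(v,u);
        dp[v]-=max(0,dp[u]);   // undo both changes before the next child
        dp[u]+=max(0,dp[v]);
    }
}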
int main() {
    IOS;
#ifdef xiaofan
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
#endif
    int n;
    cin>>n;
    for(int i=1; i<=n; i++) {
        cin>>dp[i];
        if(!dp[i]) dp[i]=-1;   // white -> +1, black -> -1, so a subgraph's weight is cnt_w - cnt_b
    }
    for(int i=0; i<n-1; i++) {
        int u,v;
        cin>>u>>v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs1(1,-1);   // root the tree at vertex 1
    dfs2(1,-1);   // reroot to get every vertex's answer
    for(int i=1; i<=n; i++)
        cout<<ans[i]<<" \n"[i==n];
    return 0;
}