传送门:HDU - 6832
思路分析
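因为第 i 条边的权值是 2^i,呈指数增长,每条新边都比之前所有边的权值之和还大(2^i > 2^1 + 2^2 + … + 2^{i-1} = 2^i − 2)。所以如果一条边加入时它的两个端点已经连通,那么经过这条边的路径一定比原有路径更长,这条边对任何最短路都没有贡献。于是按输入顺序用并查集建出最小生成树,任意两点间的最短路就是它们在这棵树上的路径;之后做一遍树形 DP:对每个点统计子树中颜色为 0/1 的点数以及它们到该点的距离和,累加所有颜色不同的点对之间的距离即可。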
AC代码
#include <bits/stdc++.h>
// Competitive-programming shorthand macros used throughout this file.
#define fi first
#define se second
#define ll long long
#define pb push_back
#define mp make_pair
#define fun function
// Function-like macros are fully parenthesised so that arguments such as
// `a+b` and surrounding operators such as `==` expand correctly.
#define sz(x) ((x).size())
// Lowest set bit of x, e.g. lowbit(12) == 4.
#define lowbit(x) ((x)&(-(x)))
#define all(x) (x).begin(),(x).end()
#define mem(a,b) memset(a,b,sizeof(a))
namespace FastIO {
#define BUF_SIZE 100000
#define OUT_SIZE 100000
// Set once stdin is exhausted; read() reports it by returning false.
bool IOerror=0;

// Returns the next byte of stdin. Bytes are served from a BUF_SIZE static
// buffer that is refilled with fread whenever it runs dry; when fread
// yields nothing, IOerror is set and -1 is returned.
inline char nc() {
    static char buf[BUF_SIZE];
    static char *cur=buf+BUF_SIZE, *lim=buf+BUF_SIZE;
    if(cur!=lim) return *cur++;
    cur=buf;
    lim=buf+fread(buf,1,BUF_SIZE,stdin);
    if(cur==lim) {
        IOerror=1;
        return -1;
    }
    return *cur++;
}

// True for the token separators: space, newline, carriage return, tab.
inline bool blank(char ch) {
    switch(ch) {
    case ' ':
    case '\n':
    case '\r':
    case '\t':
        return true;
    default:
        return false;
    }
}

// Reads an optionally negative decimal integer into x after skipping
// leading whitespace. Returns false (leaving x == 0) if IOerror is set once
// the whitespace has been skipped. Parsing stops at the first non-digit,
// which is consumed.
template<class T> inline bool read(T &x) {
    x=0;
    char ch=nc();
    while(blank(ch)) ch=nc();
    if(IOerror) return false;
    const bool negative=(ch=='-');
    if(negative) ch=nc();
    while(ch>='0'&&ch<='9') {
        x=x*10+ch-'0';
        ch=nc();
    }
    if(negative) x=-x;
    return true;
}

// Reads several integers left to right, stopping at the first failure.
template<class T,class... U>bool read(T& first,U&... rest) {
    return read(first)&&read(rest...);
}
#undef OUT_SIZE
#undef BUF_SIZE
};
using namespace std;
using namespace FastIO;
// Clock-seeded Mersenne Twister; not referenced by this solution.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
constexpr int INF = 0x3f3f3f3f;        // conventional "infinity" sentinel; not referenced here
constexpr int N = 100000 + 10;         // capacity of the per-node arrays
constexpr int mod = 1000000000 + 7;    // modulus for edge weights and the answer
// Every later `int` is really `long long`; that is why main is declared `signed main()`.
#define int long long
int n, m;                     // node / edge count of the current test case
int f[N];                     // union-find parent links
vector<pair<int,int>> e[N];   // kept spanning-tree edges: (neighbour, weight mod `mod`)
ll dis[N][2];                 // dis[u][c]: sum of distances from u to colour-c nodes in u's subtree
ll num[N][2];                 // num[u][c]: number of colour-c nodes in u's subtree
ll ans[N];                    // ans[u]: sum of distances over opposite-colour pairs inside u's subtree
ll a[N];                      // colour of each node; used directly as a 0/1 index
// Union-find lookup with full path compression: returns the representative
// of x's set and re-points every node on the path from x straight at it.
// Iterative on purpose: unions in main are unbalanced (f[fu]=fv), so parent
// chains of length ~n can form, and a recursive find would recurse that deep.
int find(int x) {
    int root=x;
    while(f[root]!=root) root=f[root];
    // Second pass: compress the path just walked.
    while(f[x]!=root) {
        int next=f[x];
        f[x]=root;
        x=next;
    }
    return root;
}
void dfs(int u,int fa){
num[u][a[u]]++;
for(auto x:e[u]){
int v=x.fi;
int w=x.se;
if(v==fa) continue;
dfs(v,u);
num[u][0]+=num[v][0];
num[u][1]+=num[v][1];
dis[u][0]=(dis[u][0] + dis[v][0] + num[v][0]*w%mod)%mod;
dis[u][1]=(dis[u][1] + dis[v][1] + num[v][1]*w%mod)%mod;
ans[u] = (ans[u]+ans[v]) %mod;
}
for(auto x:e[u]){
int v=x.fi;
int w=x.se;
if(v==fa) continue;
ans[u] = (ans[u] + (num[u][0] - num[v][0]) * (dis[v][1] + num[v][1]*w)%mod ) %mod;
ans[u] = (ans[u] + (num[u][1] - num[v][1]) * (dis[v][0] + num[v][0]*w)%mod ) %mod;
}
}
// Reads T test cases. For each one it builds the spanning tree of the
// shortest paths and prints the DP answer for root 1.
signed main() {
#ifdef xiaofan
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
#endif
    int T;
    read(T);
    while(T--) {
        read(n,m);
        // Read each node's colour and reset all per-node state for this case.
        for(int v=1; v<=n; ++v) {
            read(a[v]);
            ans[v]=0;
            e[v].clear();
            num[v][0]=0;
            num[v][1]=0;
            dis[v][0]=0;
            dis[v][1]=0;
            f[v]=v;
        }
        // Edge k weighs 2^k (tracked mod `mod`). It outweighs all earlier edges
        // combined, so only an edge joining two different components can lie
        // on a shortest path; every other edge is skipped.
        int weight=2;
        for(int k=1; k<=m; ++k, weight=weight*2%mod) {
            int x,y;
            read(x,y);
            int rx=find(x);
            int ry=find(y);
            if(rx==ry) continue;
            e[x].pb(mp(y,weight));
            e[y].pb(mp(x,weight));
            f[rx]=ry;
        }
        dfs(1,1);
        printf("%lld\n",ans[1]);
    }
    return 0;
}