提交时间:2022-07-20 12:02:30
运行 ID: 52836
#include <bits/stdc++.h>
using namespace std;

// Brute-force solution.
// Every single vertex i contributes a[i]. For every pair of vertices (i, j)
// with i < j, the values on the tree path i..j are sorted ascending into
// b[1..cnt] and sum over k of k * b[k] is added. The total is printed modulo
// `mod`.
// Cost is O(n^2 * n log n), so only small n are practical even though the
// arrays are sized for N.
const int N = 5e5 + 5, mod = 1e9 + 7;

int n;                               // number of vertices (1-indexed)
int t1, t2;                          // scratch: endpoints of the edge being read
int sum;                             // running answer, reduced modulo mod
int tot;                             // number of directed edges stored so far
int cnt;                             // number of values currently in b[]
int a[N];                            // vertex values
int b[N];                            // scratch: values on the current path
int fa[N], dep[N];                   // parent / depth in the tree rooted at 1
int head[N], nxt[2 * N], to[2 * N];  // adjacency lists, each edge stored twice

// Fills fa[] and dep[] for every vertex reachable from x, with x as the root.
// fa[x] and dep[x] themselves are not modified; both are 0 for the
// zero-initialised root 1.
// Iterative with an explicit stack: recursion depth would equal the tree
// height, which can overflow the call stack on a long path of vertices.
void dfs(int x) {
    vector<int> stk{x};
    while (!stk.empty()) {
        const int u = stk.back();
        stk.pop_back();
        for (int e = head[u]; e; e = nxt[e]) {
            const int v = to[e];
            if (v == fa[u]) continue;  // do not walk back to the parent
            fa[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }
}

// Appends a[v] for every vertex v on the tree path between x and y (both
// endpoints included) to b[], advancing cnt.
// Despite its name it does not return the lowest common ancestor. The vertex
// where the two climbs meet is the last value appended.
// Callers must reset cnt beforehand. fa[]/dep[] must already be filled by dfs().
void lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    // Lift the deeper endpoint to the same depth first.
    while (dep[x] > dep[y]) {
        b[++cnt] = a[x];
        x = fa[x];
    }
    // Then climb both endpoints together until they meet.
    while (x != y) {
        b[++cnt] = a[x];
        b[++cnt] = a[y];
        x = fa[x];
        y = fa[y];
    }
    b[++cnt] = a[x];  // the meeting vertex, counted once
}

// Reads one (optionally '-'-prefixed) decimal integer from stdin.
// Any character that cannot start a number (spaces, tabs, '\r', '\n', ...)
// is skipped. Returns 0 if EOF is reached before a digit.
// `ch` is an int so that EOF is distinguishable from every valid character.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && !isdigit(ch)) ch = getchar();
    int f = 1;
    while (ch == '-') {  // each '-' flips the sign, as before
        f = -f;
        ch = getchar();
    }
    int ret = 0;
    while (isdigit(ch)) {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
    return f * ret;
}

// Appends the directed edge x -> y to x's adjacency list.
void add(int x, int y) {
    ++tot;
    to[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

int main() {
    n = read();
    for (int i = 1; i <= n; ++i) {
        a[i] = read();
        // Single-vertex paths. Reduce on every step: adding up to 5e5 raw
        // values into an int could overflow (undefined behaviour), and for
        // n == 1 the value was previously printed unreduced.
        // Assumes non-negative vertex values, as the rest of the
        // modular arithmetic does.
        sum = static_cast<int>((sum + 1LL * a[i]) % mod);
    }
    for (int i = 1; i < n; ++i) {
        t1 = read();
        t2 = read();
        add(t1, t2);
        add(t2, t1);
    }
    dfs(1);
    for (int i = 1; i < n; ++i)
        for (int j = i + 1; j <= n; ++j) {
            cnt = 0;
            lca(i, j);
            sort(b + 1, b + cnt + 1);
            for (int k = 1; k <= cnt; ++k)
                sum = static_cast<int>((sum + 1LL * k * b[k] % mod) % mod);
        }
    printf("%d\n", sum);
    return 0;
}