提交时间:2022-07-20 12:43:22

运行 ID: 52984

#include<cstdio> #include<iostream> #include<algorithm> #define ll long long using namespace std; const int N=5e5+5,mod=1e9+7; int n,tot,a[N],b[N],pos[N],head[N],nex[N<<1],v[N<<1]; ll ans,tag[N<<2],t[N<<2],sum[N<<2],cnt[N<<2]; void Add(int x,int y) { nex[++tot]=head[x];head[x]=tot;v[tot]=y; nex[++tot]=head[y];head[y]=tot;v[tot]=x; } void pushdown(int k) { if(tag[k]!=0) { tag[k<<1]+=tag[k];tag[k<<1|1]+=tag[k]; t[k<<1]+=tag[k]*sum[k<<1];t[k<<1|1]+=tag[k]*sum[k<<1|1]; t[k<<1]%=mod;t[k<<1|1]%=mod; tag[k]=0; } } void pushup(int k) { t[k]=(t[k<<1]+t[k<<1|1])%mod; sum[k]=(sum[k<<1]+sum[k<<1|1])%mod; cnt[k]=cnt[k<<1]+cnt[k<<1|1]; } void update(int k,int l,int r,int x,int val) { if(l==r) { cnt[k]=1; sum[k]=b[l];t[k]=1ll*val*b[l]%mod; return ; } pushdown(k); int mid=(l+r)>>1; if(x<=mid) update(k<<1,l,mid,x,val); else update(k<<1|1,mid+1,r,x,val); pushup(k); } void Del(int k,int l,int r,int x) { if(l==r) { cnt[k]=sum[k]=t[k]=tag[k]=0; return; } pushdown(k); int mid=(l+r)>>1; if(x<=mid) Del(k<<1,l,mid,x); else Del(k<<1|1,mid+1,r,x); pushup(k); } void modify(int k,int l,int r,int x,int y,int val) { if(x<=l&&r<=y) { tag[k]+=val;t[k]=(t[k]+1ll*sum[k]*val%mod)%mod; return ; } if(l>y||r<x) return ; pushdown(k); int mid=(l+r)>>1; modify(k<<1,l,mid,x,y,val);modify(k<<1|1,mid+1,r,x,y,val); pushup(k); } int query(int k,int l,int r,int x,int y) { if(x<=l&&r<=y) return cnt[k]; if(l>y||r<x) return 0; pushdown(k); int mid=(l+r)>>1; return query(k<<1,l,mid,x,y)+query(k<<1|1,mid+1,r,x,y); } void dfs(int x,int fa) { int k=query(1,1,n,1,pos[x])+1; update(1,1,n,pos[x],k);modify(1,1,n,pos[x]+1,n,1); ans=(ans+t[1])%mod; // printf("%d\n",t[1]); for(int i=head[x];i;i=nex[i]) { int y=v[i]; if(y==fa) continue; dfs(y,x); } Del(1,1,n,pos[x]);modify(1,1,n,pos[x]+1,n,-1); } int main() { int x,y; scanf("%d",&n); for(int i=1;i<=n;i++) { scanf("%d",&a[i]),b[i]=a[i]; ans=(ans+a[i])%mod; } for(int i=1;i<n;i++) scanf("%d%d",&x,&y),Add(x,y); sort(b+1,b+1+n); for(int i=1;i<=n;i++) pos[i]=lower_bound(b+1,b+1+n,a[i])-b; for(int i=1;i<=n;i++) dfs(i,0); // dfs(3,0); ans=ans*((mod+1)>>1)%mod; printf("%lld",ans); return 0; }