51nod 1601 完全图的最小生成树计数 Trie+kruskal

题意:给定一个长度为n的数组a[1..n],有一幅完全图,满足(u,v)的边权为a[u] xor a[v]
求边权和最小的生成树,你需要输出边权和还有方案数对1e9+7取模的值。

由于边权是xor得到,容易想到用trie统计。。
按照当前最高位0/1将当前区间内的点分成两个部分s/t,那么答案肯定是s的最小生成树+t的最小生成树+s-t的最小边,s-t最小边用trie统计,最小生成树递归处理。
那么方案数的话就是每次那个连接两个块之间的最小边的数量,所以trie树统计一下节点个数就好。
字典树那个地方每次查询一个数,尽量使得当前位置相同就好,最后记得记录一下,每个位可能有多个点(多个数),会对方案造成贡献。
好像是经典套路?

#include<cstdio>
#include<algorithm>
#include<cstring>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int N=1e5+5;
const int mo=1e9+7;
const int inf=0x3f3f3f3f;
int n,cnt,tot,a[N],s[N],t[N],fac[N];
typedef long long ll;
ll sum;
struct node
{
    int cnt,next[2];
}ch[N*31];
inline void clear()
{
    fo(i,0,tot)
        ch[i].next[0]=ch[i].next[1]=ch[i].cnt=0;
    tot=0;
}
inline int pow(int a,int b)
{
    int ret=1;
    while (b)
    {
        if (b&1)ret=1ll*ret*a%mo;
        a=1ll*a*a%mo;
        b>>=1;
    }
    return ret;
}
inline void ins(int x)
{
    int p=0;
    fd(i,30,0)
    {
        int y=(x>>i)&1;
        if (!ch[p].next[y])
            ch[p].next[y]=++tot;
        p=ch[p].next[y];
    }
    ch[p].cnt++;
}
inline pair<int,int> find(int x)
{
    int p=0,ans=0,y;
    fd(i,30,0)
    {
        y=(x>>i)&1;
        if (ch[p].next[y])p=ch[p].next[y],ans|=y<<i;
        else p=ch[p].next[y^1],ans|=(y^1)<<i;
    }
    return make_pair(ans^x,ch[p].cnt);
}
inline void solve(int l,int r,int dep)
{
    if (l>=r)return;
    if (dep<0)
    {
        if (r-l+1>=2)cnt=1ll*cnt*pow(r-l+1,r-l-1)%mo;
        return;
    }
    int cnt1=0,cnt2=0;
    fo(i,l,r)
        if ((a[i]>>dep)&1)s[cnt1++]=a[i];
        else t[cnt2++]=a[i];
    fo(i,0,cnt1-1)a[l+i]=s[i];
    fo(i,0,cnt2-1)a[l+cnt1+i]=t[i];
    clear();
    pair<int,int>tmp;
    int ans=inf,tot=0;
    fo(i,0,cnt2-1)ins(t[i]);
    fo(i,0,cnt1-1)
    {
        tmp=find(s[i]);
        if (tmp.first<ans)ans=tmp.first,tot=tmp.second;
        else if (tmp.first==ans)
            tot+=tmp.second;
    }
    if (sum!=inf&&tot)sum+=ans,cnt=1ll*tot*cnt%mo;
    solve(l,l+cnt1-1,dep-1);
    solve(l+cnt1,r,dep-1);
}
int main()
{

    scanf("%d",&n);
    fac[0]=cnt=1;
    fo(i,1,n)fac[i]=1ll*fac[i-1]*i%mo;
    fo(i,1,n)scanf("%d",&a[i]);
    solve(1,n,30);
    printf("%lld\n%d\n",sum,cnt);
}
    原文作者:Trie树
    原文地址: https://blog.csdn.net/qq_35866453/article/details/77899285
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞