参考:树分治论文
传送门:poj 1741 Tree
题意:给你一棵N(N<=10000)个节点的带权数,定义dist(u,v)为u,v两点之间的距离,再给定一个K
(1<=K<=10^9),如果对于两个不同的节点a,b,如果满足dist(a,b,)<=k,则称(a,b)为合法点对,
求合法点对数目
思路:我们知道一条路径要么过根节点,要么在一棵子树中,所以我们可以利用分治
路径在子树中的情况只需递归处理即可, 下面我们来分析如何处理路径过根结点的情况。
记depth[i]为点i到根节点的路径长度,Belong[i]=x(x为根节点的某个儿子,且节点i在以X为根的子树中)
那么我们要统计的就是:满足depth[i]+depth[j]<=k且Belong[i]!=Belong[j]的(i,j)个数=
depth[i]+depth[j]<=k的个数-满足depth[i]+depth[j]<=k且Belong[i]=Belong[j]的(i,j)个数
而对于这两个部分, 都是要求出满足Ai+Aj<=k的(i,j)的对数
将 A 排序后利用单调性我们很容易得出一个O(n)的算法,所以我们
可以用O(NlogN)的时间来解决这个问题。
所以总的时间复杂度为O(NlogN^2)
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <ctime>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define INF 0x3f3f3f3f
#define inf -0x3f3f3f3f
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mem0(a) memset(a,0,sizeof(a))
#define mem1(a) memset(a,-1,sizeof(a))
#define mem(a, b) memset(a, b, sizeof(a))
#define MP(x,y) make_pair(x,y)
typedef long long ll;
void fre() { freopen("input.in", "r", stdin); freopen("output.out", "w", stdout); }
template <class T1, class T2>inline void gmax(T1 &a, T2 b) { if (b>a)a = b; }
template <class T1, class T2>inline void gmin(T1 &a, T2 b) { if (b<a)a = b; }
typedef pair<int,int>PI;
const int maxn=10010;
vector<PI>G[maxn];
vector<int>dep;
int dis[maxn],k,size[maxn],f[maxn],Count,root;//Count表示当前子树的结点的总个数
bool Del[maxn];
long long ans=0;
void getroot(int u,int pre){
size[u]=1,f[u]=0;
for(int i=0;i<G[u].size();++i){
int v=G[u][i].first;
if(v!=pre && !Del[v]){
getroot(v,u);
size[u]+=size[v];
f[u]=max(f[u],size[v]);
}
}
f[u]=max(f[u],Count-size[u]);
if(f[u]<f[root]) root=u;
}
void getdep(int u,int pre){ //这里还需要重新计算每个子树的size
dep.push_back(dis[u]);
size[u]=1;
for(int i=0;i<G[u].size();++i){
int v=G[u][i].first;
if(v!=pre && !Del[v]){
dis[v]=dis[u]+G[u][i].second;
getdep(v,u);
size[u]+=size[v];
}
}
}
long long cal(int u,int bg_depth){
dep.clear();dis[u]=bg_depth;
getdep(u,0);
sort(dep.begin(),dep.end());
long long ret=0;
for(int l=0,r=dep.size()-1;l<r;){
if(dep[l]+dep[r]<=k) ret+=(r-l),l++;
else r--;
}
return ret;
}
void work(int u){
ans+=cal(u,0);
Del[u]=true;
for(int i=0;i<G[u].size();i++){
int v=G[u][i].first;
if(!Del[v]){
ans-=cal(v,G[u][i].second);
f[0]=Count=size[v];
getroot(v,root=0);
work(root);
}
}
}
int main(){
int n;
while(scanf("%d%d",&n,&k)!=EOF){
if(n==0&&k==0)
break;
for(int i=1;i<=n;i++) G[i].clear();
int u,v,w;
for(int i=1;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
G[u].push_back(MP(v,w));
G[v].push_back(MP(u,w));
}
ans=0;
f[0]=Count=n;
getroot(1,root=0);
memset(Del,false,sizeof(Del));
work(root);
printf("%lld\n",ans);
}
return 0;
}
思路二:首先,我们先对每个点算出它的子节点的子树的最大是哪一个结点
然后每次合并都是把其他结点数较小的合并到大的上面,可以发现每个点最多被合并logn次
要先计算贡献再进行合并
/* 思路: */
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <ctime>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define INF 0x3f3f3f3f
#define inf -0x3f3f3f3f
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mem0(a) memset(a,0,Countof(a))
#define mem1(a) memset(a,-1,Countof(a))
#define mem(a, b) memset(a, b, Countof(a))
#define MP(x,y) make_pair(x,y)
typedef long long ll;
typedef pair<int,int>PI;
const int maxn=10010;
vector<PI>G[maxn];
long long ans;
int k,Count[maxn],length[maxn],son[maxn];
int size[maxn*30],root[maxn],tot,cnt[maxn*30],val[maxn*30],ch[maxn*30][2],fa[maxn*30];
//按值的大小建立Splay
void Newnode(int &now,int father,int k,int siz){
now=++tot;
cnt[now]=siz,size[now]=siz,val[now]=k,fa[now]=father,ch[now][0]=ch[now][1]=0;
}
void dfs1(int u,int pre){
Count[u]=1,son[u]=-1;
for(int i=0;i<G[u].size();++i){
int v=G[u][i].first;
if(v!=pre){
dfs1(v,u);
Count[u]+=Count[v];
if(son[u]==-1||Count[son[u]]<Count[v])
son[u]=v,length[u]=G[u][i].second;
}
}
}
void pushup(int x){
size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}
//旋转,kind为1为右旋,kind为0为左旋
void Rotate(int x,int kind){
int y=fa[x];
ch[y][!kind]=ch[x][kind];
fa[ch[x][kind]]=y;
//如果父节点不是根结点,则要和父节点的父节点连接起来
if(fa[y])
ch[fa[y]][ch[fa[y]][1]==y]=x;
fa[x]=fa[y];
ch[x][kind]=y;
fa[y]=x;
pushup(y);
}
//Splay调整,将根为now的子树调整为goal
void Splay(int now,int goal,int Belong){
while(fa[now]!=goal){
if(fa[ fa[now] ]==goal)
Rotate(now,ch[ fa[now] ][0]==now);
else{
int pre=fa[now],kind=ch[ fa[pre] ][0]==pre; //左儿子为1,右儿子为0
if(ch[pre][kind]==now){ //两个方向不同
Rotate(now,!kind);
Rotate(now,kind);
}
else{ //两个方向相同
Rotate(pre,kind);
Rotate(now,kind);
}
}
}
if(goal==0) root[Belong]=now;
pushup(now);
}
int getans(int u,int dep,int dis){
if(!u) return 0;
int ret=0;
if(val[u]+dis-2*dep<=k){
ret+=size[ch[u][0]]+cnt[u];
ret+=getans(ch[u][1],dep,dis);
}
else
ret+=getans(ch[u][0],dep,dis);
return ret;
}
int Insert(int now,int k,int Belong,int siz){
while(ch[now][val[now]<k]){
//不重复插入
if(val[now]==k){
Splay(now,0,Belong);
cnt[now]+=siz;
pushup(now);
return 0;
}
now=ch[now][val[now]<k];
}
if(val[now]==k){
Splay(now,0,Belong);
cnt[now]+=siz;
pushup(now);
return 0;
}
Newnode(ch[now][k>val[now]],now,k,siz);
//将新插入的结点更新至根结点
Splay(ch[now][k>val[now]],0,Belong);
return 1;
}
void query(int now,int u,int dep){
if(!u) return ;
if(val[u]!=-INF&&val[u]!=INF)
ans+=1LL*(getans(now,dep,val[u])-1)*cnt[u];
query(now,ch[u][0],dep);
query(now,ch[u][1],dep);
}
void merge(int u,int Belong){
if(!u) return ;
if(val[u]!=-INF&&val[u]!=INF)
Insert(root[Belong],val[u],Belong,cnt[u]);
merge(ch[u][0],Belong);
merge(ch[u][1],Belong);
}
void dfs2(int u,int dep,int pre){
if(son[u]!=-1){
dfs2(son[u],dep+length[u],u);
root[u]=root[son[u]];
ans+=getans(root[u],dep,dep)-1;
Insert(root[u],dep,u,1);
}
else{
root[u]=0;
Newnode(root[u],0,-INF,1);
Newnode(ch[root[u]][1],root[u],INF,1);//头尾各加入一个空位
Newnode(ch[ch[root[u]][1]][0],ch[root[u]][1],dep,1);
pushup(ch[root[u]][1]);
pushup(root[u]);
}
for(int i=0;i<G[u].size();++i){
int v=G[u][i].first;
if(v!=pre && v!=son[u]){
dfs2(v,dep+G[u][i].second,u);
query(root[u],root[v],dep);
merge(root[v],u);
}
}
}
int main(){
int n;
while(scanf("%d%d",&n,&k)!=EOF){
tot=0;
if(n==0&&k==0)
break;
for(int i=1;i<=n;i++) G[i].clear();
int u,v,w;
for(int i=1;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
G[u].push_back(MP(v,w));
G[v].push_back(MP(u,w));
}
dfs1(1,0);
ans=0;
dfs2(1,0,0);
printf("%lld\n",ans);
}
return 0;
}