题目:http://www.lydsy.com/JudgeOnline/problem.php?id=3451
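这题目实在是太神了!对于有序点对 (x,y),点 x 出现在以 y 为分治中心的那一层子树中(贡献 1 的代价),当且仅当 y 是路径 (x,y) 上所有点中最早被选为分治中心的,其概率为 1/dist(x,y),这里 dist(x,y) 表示路径上的点数(即边数+1)。由期望的线性性,答案就是对所有有序点对(含 u=v)求 sigma(1/dist(u,v)),然后用点分治+FFT统计每种距离的点对数,O(n log^2 n)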
代码:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <vector>
using namespace std ;
// Walk p over the adjacency list of vertex x (uses the `edge` type and the `head` array).
#define travel( x ) for ( edge *p = head[ x ] ; p ; p = p -> next )
// Loop body sees i = 1 .. x (i is incremented inside the loop condition).
#define rep( i , x ) for ( int i = 0 ; i ++ < x ; )
// Loop body sees i = l .. r, both ends inclusive.
#define REP( i , l , r ) for ( int i = l ; i <= r ; ++ i )
// Loop body sees i = 0 .. x - 1.
#define Rep( i , x ) for ( int i = 0 ; i < x ; ++ i )
// Builds a `com` value from ( real , imaginary ). This is compound-literal syntax,
// which compilers accept in C++ only as an extension.
#define com( a , b ) ( com ) { a , b }
typedef long long ll ;
// Base capacity used to size the global arrays in this file.
const int maxn = 101000 ;
const double PI = acos( -1.0 ) ;
struct com {
double a , b ;
com operator * ( const com &x ) const {
return com( ( a * x.a - b * x.b ) , ( b * x.a + a * x.b ) ) ;
}
com operator + ( const com &x ) const {
return com( ( a + x.a ) , ( b + x.b ) ) ;
}
com operator - ( const com &x ) const {
return com( ( a - x.a ) , ( b - x.b ) ) ;
}
} A[ maxn ] ;
// tra[ i ] holds the bit-reversed index of i for the transform length currently in use.
int tra[ maxn ] ;

/*
 * In-place iterative radix-2 FFT of a[ 0 .. N ).
 *
 *   a    - input sequence; overwritten with the transform.
 *   N    - transform length. The bit-reversal construction below assumes N is a
 *          power of two, and N must not exceed maxn (the size of tra[] and A[]).
 *   flag - true:  forward transform, twiddle factor exp( +2*pi*i / len );
 *          false: inverse transform, twiddle factor exp( -2*pi*i / len ),
 *                 followed by division by N.
 *
 * The inverse now scales BOTH the real and the imaginary part by 1/N. Previously
 * only the real part was divided, so the imaginary part of an inverse transform
 * came back N times too large. Code that reads only `.a` sees the same values.
 */
inline void FFT( com *a , int N , bool flag ) {
    // Build the bit-reversal permutation: first place the mirrored top bit of
    // every index, then OR in the mirrored lower bits from the smaller index.
    for ( int i = 0 ; i < N ; ++ i ) tra[ i ] = 0 ;
    for ( int i = 1 , j = N >> 1 ; i < N ; i <<= 1 , j >>= 1 )
        for ( int k = i ; k < ( i << 1 ) ; ++ k ) tra[ k ] = j ;
    for ( int i = 1 ; i < N ; i <<= 1 )
        for ( int j = 0 ; j < i ; ++ j ) tra[ j + i ] |= tra[ j ] ;

    // Work in the scratch buffer A, starting from the permuted input.
    for ( int i = 0 ; i < N ; ++ i ) A[ i ] = a[ tra[ i ] ] ;

    const double pi = flag ? PI : ( - PI ) ;
    for ( int i = 1 ; i < N ; i <<= 1 ) {
        // i is the half-length of the butterflies merged in this pass.
        const double angle = ( 2.0 * pi ) / double( i << 1 ) ;
        const com e = com( cos( angle ) , sin( angle ) ) ;
        com w = com( 1 , 0 ) ;
        for ( int j = 0 ; j < i ; ++ j , w = w * e ) {
            for ( int k = j ; k < N ; k += ( i << 1 ) ) {
                const com rec = A[ k ] ;
                const com ret = w * A[ k + i ] ;
                A[ k ] = rec + ret ;
                A[ k + i ] = rec - ret ;
            }
        }
    }

    if ( ! flag ) {
        const double scale = double( N ) ;
        for ( int i = 0 ; i < N ; ++ i ) {
            A[ i ].a /= scale ;
            A[ i ].b /= scale ;
        }
    }
    for ( int i = 0 ; i < N ; ++ i ) a[ i ] = A[ i ] ;
}
// FFT work buffers filled by solve(): a and b receive the two depth histograms,
// c receives their point-wise product before the inverse transform.
com a[ maxn ] , b[ maxn ] , c[ maxn ] ;
// ans[ d ] accumulates the pair counts solve() finds for distance d;
// main() sums ans[ d ] / ( d + 1 ).
ll ans[ maxn ] ;
// Node of a singly linked adjacency list: `t` is the neighbouring vertex,
// `next` the following node in the same vertex's list.
struct edge {
    edge *next ;
    int t ;
} E[ maxn << 1 ] ;   // node pool: two directed arcs per undirected edge

// pt: next unused slot of the pool E. head[ v ]: first node of v's list, null if empty.
edge *pt = E , *head[ maxn ] ;

// Prepend the directed arc s -> t to the adjacency list of s.
inline void add( int s , int t ) {
    edge *node = pt ++ ;
    node -> t = t ;
    node -> next = head[ s ] ;
    head[ s ] = node ;
}

// Insert the undirected edge { s , t } as two directed arcs.
inline void addedge( int s , int t ) {
    add( s , t ) ;
    add( t , s ) ;
}
bool del[ maxn ] ;
int n , size[ maxn ] , root , rt ;
void gets( int now , int fa ) {
size[ now ] = 1 ;
travel( now ) if ( ! del[ p -> t ] && p -> t != fa ) {
gets( p -> t , now ) ;
size[ now ] += size[ p -> t ] ;
}
}
void getrt( int now , int fa ) {
if ( root ) return ;
int ret = size[ rt ] / 2 ;
bool flag = ( size[ rt ] - size[ now ] ) <= ret ;
travel( now ) if ( p -> t != fa && ! del[ p -> t ] ) {
if ( size[ p -> t ] > ret ) flag = false ;
getrt( p -> t , now ) ;
}
if ( flag ) root = now ;
}
int h[ maxn ] , cnt[ maxn ] , mh , m , H[ maxn ] ;
void geth( int now , int fa ) {
if ( h[ now ] > mh ) mh = h[ now ] ;
travel( now ) if ( p -> t != fa && ! del[ p -> t ] ) {
h[ p -> t ] = h[ now ] + 1 ;
geth( p -> t , now ) ;
}
}
vector < int > sub[ maxn ] , tak[ maxn ] ;
void getsub( int now , int fa , int num ) {
sub[ num ].push_back( now ) , mh = max( mh , h[ now ] ) ;
travel( now ) if ( p -> t != fa && ! del[ p -> t ] ) getsub( p -> t , now , num ) ;
}
void solve( int now ) {
gets( now , 0 ) ;
root = 0 , rt = now ; getrt( now , 0 ) ;
h[ root ] = 0 ; geth( root , 0 ) ;
REP( i , 0 , mh ) tak[ i ].clear( ) , cnt[ i ] = 0 ;
int Mh = mh ;
cnt[ 0 ] = 1 , ans[ 0 ] ++ ;
travel( root ) if ( ! del[ p -> t ] ) {
sub[ p -> t ].clear( ) ;
mh = 0 ; getsub( p -> t , root , p -> t ) ;
tak[ mh ].push_back( p -> t ) ;
H[ p -> t ] = mh ;
}
REP( i , 0 , Mh ) Rep( j , tak[ i ].size( ) ) {
int x = tak[ i ][ j ] , m ;
for ( m = 1 ; m <= H[ x ] ; m <<= 1 ) ; m <<= 1 ;
Rep( k , m ) a[ k ] = b[ k ] = com( 0 , 0 ) ;
REP( k , 0 , H[ x ] ) a[ k ].a = double( cnt[ k ] ) ;
Rep( k , sub[ x ].size( ) ) b[ h[ sub[ x ][ k ] ] ].a += 1.0 ;
FFT( a , m , true ) , FFT( b , m , true ) ;
Rep( k , m ) c[ k ] = a[ k ] * b[ k ] ;
FFT( c , m , false ) ;
Rep( k , m ) ans[ k ] += int( c[ k ].a + 0.5 ) * 2 ;
Rep( k , sub[ x ].size( ) ) cnt[ h[ sub[ x ][ k ] ] ] ++ ;
}
del[ root ] = true ;
travel( root ) if ( ! del[ p -> t ] ) solve( p -> t ) ;
}
int main( ) {
memset( head , 0 , sizeof( head ) ) , memset( ans , 0 , sizeof( ans ) ) ;
scanf( "%d" , &n ) ;
REP( i , 2 , n ) {
int s , t ; scanf( "%d%d" , &s , &t ) ; addedge( ++ s , ++ t ) ;
}
memset( del , false , sizeof( del ) ) ;
solve( 1 ) ;
double Ans = 0 ;
REP( i , 0 , n ) Ans += double( ans[ i ] ) / double( i + 1 ) ;
printf( "%.4f\n" , Ans ) ;
return 0 ;
}