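题目：http://www.lydsy.com/JudgeOnline/problem.php?id=2243
思路：先对树做轻重链剖分，使每条重链在 DFS 序中占据一段连续区间；再用线段树维护每个区间的颜色段数以及左、右端点颜色，支持区间赋值（懒标记）与区间合并查询。路径 u–v 拆成 u→LCA 和 v→LCA 两段分别查询，两段都包含 LCA 所在的颜色段，因此答案为两段颜色段数之和减一。
代码：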
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std ;
// Segment-tree child indices for a 1-based heap layout (root = 1).
// Arguments are parenthesised so expressions such as left( a | b ) expand correctly.
#define left( t ) ( ( t ) << 1 )
#define right( t ) ( left( t ) ^ 1 )
// Upper bound on vertex count (array sizes).
#define MAXN 101000
// Undirected edge = two directed adjacency-list entries.
#define AddEdge( s , t ) ( Add( ( s ) , ( t ) ) , Add( ( t ) , ( s ) ) )
// Number of binary-lifting levels allocated per vertex in up[][].
#define MAXB 20
// Adjacency-list edge: target vertex t and the next edge leaving the same source.
struct edge {
    int t ;
    edge *next ;
} ;
// head[ s ] = most recently added edge leaving s (null when s has no edges).
edge *head[ MAXN ] ;
void Add( int s , int t ) {
edge *p = new( edge ) ;
p -> t = t , p -> next = head[ s ] ;
head[ s ] = p ;
}
// Per-vertex tree bookkeeping. Vertices are 1-based; index 0 acts as a depth-0 sentinel parent.
bool f[ MAXN ] ;          // true = not yet visited in the current DFS pass
int size[ MAXN ] ;        // subtree size; NOTE(review): may clash with C++17 std::size under `using namespace std`
int pchild[ MAXN ] ;      // heavy child (largest subtree), 0 for a leaf
int h[ MAXN ] ;           // depth, root has depth 1
int fir[ MAXN ] ;         // top vertex of the heavy chain the vertex belongs to
int pos[ MAXN ] ;         // DFS position of a vertex (heavy child visited first)
int arr[ MAXN ] ;         // inverse of pos: vertex stored at a DFS position
int up[ MAXN ][ MAXB ] ;  // up[ v ][ k ] = 2^k-th ancestor of v (0 past the root)
int B ;                   // highest binary-lifting level filled in
int Index = 0 ;           // last DFS position handed out
void dfs0( int v ) {
f[ v ] = false , size[ v ] = 1 ;
int rec = 0 ;
for ( edge *p = head[ v ] ; p ; p = p -> next ) if ( f[ p -> t ] ) {
h[ p -> t ] = h[ v ] + 1 , up[ p -> t ][ 0 ] = v ;
dfs0( p -> t ) ;
if ( size[ p -> t ] > rec ) {
rec = size[ p -> t ] , pchild[ v ] = p -> t ;
}
size[ v ] += size[ p -> t ] ;
}
}
void dfs1( int v , int u ) {
f[ v ] = false , fir[ v ] = u , arr[ pos[ v ] = ++ Index ] = v ;
if ( pchild[ v ] ) {
dfs1( pchild[ v ] , u ) ;
for ( edge *p = head[ v ] ; p ; p = p -> next ) if ( f[ p -> t ] ) {
dfs1( p -> t , p -> t ) ;
}
}
}
// Summary of a sequence of colours: number of maximal same-colour segments (num),
// plus the colours at its left and right ends. num == 0 denotes the empty sequence.
struct Color {
    int num , lcol , rcol ;
    // Empty sequence. Every member is initialised so that copying an empty Color
    // never reads indeterminate values.
    Color( ) : num( 0 ) , lcol( 0 ) , rcol( 0 ) { }
    // Reverses the orientation of the summarised sequence.
    void SWAP( ) {
        std::swap( lcol , rcol ) ;
    }
    // Resets to a non-empty sequence made entirely of colour col.
    void Change( int col ) {
        num = 1 , lcol = rcol = col ;
    }
};
// Concatenates x followed by y (order matters). If x's right end and y's left end share
// a colour, those two segments merge into one.
Color operator + ( const Color &x , const Color &y ) {
    if ( ! x.num ) return y ;
    if ( ! y.num ) return x ;
    Color ret ;
    ret.num = x.num + y.num - ( x.rcol == y.lcol ) ;
    ret.lcol = x.lcol , ret.rcol = y.rcol ;
    return ret ;
}
// Segment-tree node over a range of DFS positions.
struct node {
    Color c ;     // colour summary of positions [ l , r ]
    int l , r ;   // covered range of DFS positions
    int flag ;    // pending colour assignment not yet applied to c, -1 if none
    node( ) : flag( - 1 ) { }
} sgt[ MAXN << 2 ] ;
int n , m ;       // vertex count and number of operations
int C[ MAXN ] ;   // initial colour of each vertex (1-based)
// Applies node t's pending assignment (flag) to its own summary and hands the tag,
// still unapplied, to both children. No-op when no assignment is pending.
void pushdown( int t ) {
    const int tag = sgt[ t ].flag ;
    if ( tag == - 1 ) return ;
    sgt[ t ].c.Change( tag ) ;
    if ( sgt[ t ].l < sgt[ t ].r ) {
        sgt[ right( t ) ].flag = tag ;
        sgt[ left( t ) ].flag = tag ;
    }
    sgt[ t ].flag = - 1 ;
}
void update( int t ) {
pushdown( t ) ;
if ( sgt[ t ].l < sgt[ t ].r ) {
pushdown( left( t ) ) , pushdown( right( t ) ) ;
sgt[ t ].c = sgt[ left( t ) ].c + sgt[ right( t ) ].c ;
}
}
void change( int l , int r , int col , int t ) {
if ( l == sgt[ t ].l && r == sgt[ t ].r ) {
sgt[ t ].flag = col ; pushdown( t ) ; return ;
}
pushdown( t ) ;
int mid = ( sgt[ t ].l + sgt[ t ].r ) >> 1 ;
if ( r <= mid ) change( l , r , col , left( t ) ) ; else
if ( l > mid ) change( l , r , col , right( t ) ) ; else {
change( l , mid , col , left( t ) ) ;
change( mid + 1 , r , col , right( t ) ) ;
}
update( t ) ;
}
// Builds node t over DFS positions [ l , r ]; each leaf takes the initial colour of the
// vertex stored at its position.
void Build( int l , int r , int t ) {
    sgt[ t ].l = l ;
    sgt[ t ].r = r ;
    if ( l != r ) {
        const int mid = ( l + r ) >> 1 ;
        Build( l , mid , left( t ) ) ;
        Build( mid + 1 , r , right( t ) ) ;
        update( t ) ;
    } else {
        sgt[ t ].c.Change( C[ arr[ l ] ] ) ;
    }
}
// Returns the colour summary of positions [ l , r ] in left-to-right order.
// [ l , r ] is expected to lie inside node t's range.
Color query( int l , int r , int t ) {
    pushdown( t ) ;
    const node &cur = sgt[ t ] ;
    if ( cur.l == l && cur.r == r ) return cur.c ;
    const int mid = ( cur.l + cur.r ) >> 1 ;
    if ( r <= mid ) return query( l , r , left( t ) ) ;
    if ( mid < l ) return query( l , r , right( t ) ) ;
    const Color lhs = query( l , mid , left( t ) ) ;
    const Color rhs = query( mid + 1 , r , right( t ) ) ;
    return lhs + rhs ;
}
// Reads the tree from stdin and prepares the heavy-light decomposition and the
// binary-lifting table. Input: n m, then n colours, then n-1 undirected edges.
void Init( ) {
    // Clear adjacency lists and per-vertex tables.
    memset( head , 0 , sizeof( head ) ) ;
    memset( h , 0 , sizeof( h ) ) ;
    memset( up , 0 , sizeof( up ) ) ;
    memset( pchild , 0 , sizeof( pchild ) ) ;

    scanf( "%d%d" , &n , &m ) ;
    for ( int v = 1 ; v <= n ; ++ v ) scanf( "%d" , &C[ v ] ) ;
    for ( int e = 1 ; e < n ; ++ e ) {
        int s , t ;
        scanf( "%d%d" , &s , &t ) ;
        AddEdge( s , t ) ;
    }

    // Pass 1 from root 1 (depth 1; vertex 0 is the depth-0 sentinel parent).
    memset( f , true , sizeof( f ) ) ;
    h[ 1 ] = 1 ;
    dfs0( 1 ) ;

    // Pass 2: heavy-light numbering; the root's chain starts at the root.
    memset( f , true , sizeof( f ) ) ;
    dfs1( 1 , 1 ) ;

    // Binary lifting: level k is built from level k - 1.
    B = int( log2( n ) ) + 1 ;
    for ( int k = 1 ; k <= B ; ++ k )
        for ( int j = 1 ; j <= n ; ++ j )
            up[ j ][ k ] = up[ up[ j ][ k - 1 ] ][ k - 1 ] ;
}
int Lca( int u , int v ) {
if ( h[ u ] < h[ v ] ) swap( u , v ) ;
for ( int i = B ; i >= 0 ; -- i ) {
if ( h[ up[ u ][ i ] ] >= h[ v ] ) u = up[ u ][ i ] ;
}
if ( u == v ) return u ;
for ( int i = B ; i >= 0 ; -- i ) {
if ( up[ u ][ i ] != up[ v ][ i ] ) {
u = up[ u ][ i ] , v = up[ v ][ i ] ;
}
}
return up[ u ][ 0 ] ;
}
// Colour summary of the vertical path from lca down to v (lca included), ordered top to
// bottom. lca is expected to be an ancestor of v (or v itself).
Color Query( int v , int lca ) {
    Color acc ;
    while ( h[ v ] >= h[ lca ] ) {
        const int top = fir[ v ] ;
        if ( h[ top ] <= h[ lca ] ) {
            // lca lies on v's chain: take the last piece lca..v and stop.
            acc = query( pos[ lca ] , pos[ v ] , 1 ) + acc ;
            break ;
        }
        // Whole chain segment top..v, then jump above the chain.
        acc = query( pos[ top ] , pos[ v ] , 1 ) + acc ;
        v = up[ top ][ 0 ] ;
    }
    return acc ;
}
// Paints every vertex on the vertical path from v up to lca (inclusive) with colour col.
// lca is expected to be an ancestor of v (or v itself).
void Change( int v , int lca , int col ) {
    while ( h[ v ] >= h[ lca ] ) {
        const int top = fir[ v ] ;
        if ( h[ top ] <= h[ lca ] ) {
            change( pos[ lca ] , pos[ v ] , col , 1 ) ;  // final piece on lca's chain
            break ;
        }
        change( pos[ top ] , pos[ v ] , col , 1 ) ;
        v = up[ top ][ 0 ] ;
    }
}
// Builds the segment tree, then processes up to m operations from stdin:
//   Q u v     -> print the number of colour segments on the path u..v
//   C u v col -> paint every vertex on the path u..v with col
// Processing stops early if input ends or an operation's operands cannot be read.
// Previously, reaching EOF while skipping to the next op letter made the loop spin forever.
void Solve( ) {
    Build( 1 , n , 1 ) ;
    while ( m -- ) {
        int ch = getchar( ) ;
        while ( ch != 'Q' && ch != 'C' && ch != EOF ) ch = getchar( ) ;
        if ( ch == EOF ) break ;
        if ( ch == 'Q' ) {
            int u , v ;
            if ( scanf( "%d%d" , &u , &v ) != 2 ) break ;
            const int lca = Lca( u , v ) ;
            Color rec = Query( u , lca ) , ret = Query( v , lca ) ;
            // Both halves begin with lca's segment, so that segment is counted twice.
            printf( "%d\n" , rec.num + ret.num - 1 ) ;
        } else {
            int u , v , col ;
            if ( scanf( "%d%d%d" , &u , &v , &col ) != 3 ) break ;
            const int lca = Lca( u , v ) ;
            Change( u , lca , col ) ;
            Change( v , lca , col ) ;
        }
    }
}
// Reads the tree and builds the decomposition, then answers the operations.
// Falling off the end of main returns 0.
int main( ) {
    Init( ) ;
    Solve( ) ;
}