Trie树总结(codeforces)

Trie树总结

理解

     试想一种情形,给出一堆字符串,将其根据前缀是否相同而将其组成一颗树,维护两个条件,一是对这颗树的所有结点进行标记,二是在每个字符串结束时做标记.

数组写法

struct Trie
{
    int ch[ maxn ][ 26 ]; //结点
    int val[ maxn ];
    int tot;   // 结点标记

    void Init()
    {
        memset( ch, 0, sizeof( ch ) );
        memset( val, 0, sizeof( val ) );
        tot = 1;
    }

    void Insert( string s )
    {
        int root = 0;
        for( auto x: s )
        {
            int Index = x - 'a';
            if( !ch[ root ][ Index ] ) ch[ root ][ Index ] = tot ++;
            root = ch[ root ][ Index ];
        }
        val[ root ] = 1;
    }
    
    int Search( string s )
    {
        int root = 0;
        for( auto x: s )
        {
            int Index = x - 'a';
            if( !ch[ root ][ Index ] )  return 0;
            root = ch[ root ][ Index ];
        }
        return val[ root ];
    }
};

数组写法特点

1.需要分析maxn的取值,最坏情况就是全部散开,即 n * maxwordslength

2.数组写法不适合leetcode

题目列表

  • 557E Ann and Half-Palindrome
  • 633C Spy Syndrome 2
  • 665E Beautiful Subarrays

557E Ann and Half-Palindrome

题意

定义了半回文串,即奇数位置上的数是回文的,给出字符串,求其第k大的半回文子串

题解

首先求出子串的半回文串条件
dp[ i ][ j ] 表示s[ i – j ]中是半回文子串的个数
o(n^2)可以求出.在求出过程中不能直接把半回文子串插入,否则超时.
然后把s的全部子串都插入到Trie中.在Trie中记录当前结点有多少个半回文子串经过
然后dfs求出满足条件的子串
在dfs时,由父串至子串过程中,减去满足在父串停留的满足条件的字符串个数.

代码
#include<bits/stdc++.h>
using namespace std;

const int maxn = 5010;
const int Max  = 2e7;

int dp[ maxn ][ maxn ];
string s;
int k, n;

struct Trie
{
    int ch[ Max ][ 2 ];
    int num[ Max ];
    int tot;

    void Init()
    {
        memset( ch, 0, sizeof( ch ) );
        memset( num, 0, sizeof( num ) );
        tot = 1;
    }

    void Insert( int k )
    {
        int root = 0;
        for( int i = k; i < n; ++ i )
        {
            int Index = s[ i ] - 'a';
            if( !ch[ root ][ Index ] ) ch[ root ][ Index ] = tot ++;
            root = ch[ root ][ Index ];
            if( k == i ) num[ root ] += dp[ k ][ n - 1 ];
            else         num[ root ] += ( dp[ k ][ n - 1 ] - dp[ k ][ i - 1 ] );
        }
    }

    void Dfs( int root, int k )
    {
        if( k <= 0 ) return;
        int num_a = ch[ root ][ 0 ];
        int num_b = ch[ root ][ 1 ];
        if( num_a && num[ num_a ] >= k )
        {
            cout << "a";
            if( ch[ num_a ][ 0 ] )  num[ num_a ] -= num[ ch[ num_a ][ 0 ] ];
            if( ch[ num_a ][ 1 ] )  num[ num_a ] -= num[ ch[ num_a ][ 1 ] ];
            Dfs( num_a, k - num[ num_a ] );
        }
        else
        {
            if( num_a ) k -= num[ num_a ];
            cout << "b";
            if( ch[ num_b ][ 0 ] )  num[ num_b ] -= num[ ch[ num_b ][ 0 ] ];
            if( ch[ num_b ][ 1 ] )  num[ num_b ] -= num[ ch[ num_b ][ 1 ] ];
            Dfs( num_b, k - num[ num_b ] );
        }
    }

};

Trie T;

int main()
{
    cin >> s;
    cin >> k;
    n = s.length();
    for( int i = 0; i < n; ++ i ) dp[ i ][ i ] = 1;
    for( int i = 0; i < n - 1; ++ i ) dp[ i ][ i + 1 ] = ( s[ i ] == s[ i + 1 ] );
    for( int i = 2; i < n; ++ i )
    for( int j = 0; j + i < n; ++ j )
    {
        if( s[ j ] == s[ j + i ] ) dp[ j ][ j + i ] = ( j + 2 >= j + i - 2 ? 1 : dp[ j + 2 ][ j + i - 2 ] );
    }
    for( int i = 1; i < n; ++ i )
    for( int j = 0; j < n - 1; ++ j ) dp[ j ][ j + i ] += dp[ j ][ j + i - 1 ];

    T.Init();
    for( int i = 0; i < n; ++ i ) T.Insert( i );
    T.Dfs( 0, k );
    return 0;
}

633C Spy Syndrome 2

题意

一堆单词,把这一堆单词都翻转了,拼成了一个串。
然后现在给你一个串,让你找到原来拼的那些单词是什么。

题解

先把当前串s翻转,然后将所有串插入Trie中
数组vis[ i ]标记标记当前s[i]是否到达,然后记录前缀.因为任意一个满足条件的均可以. 所以从前往后更新vis即可.

代码
#include<bits/stdc++.h>
using namespace std;

const int maxn = 1e6 + 10;
const int Max = 1e5 + 10;
int pre[ Max ], Vis[ Max ], Mappre[ Max ];
map< int, string > Map;
vector< string > goal;
string s, tmp;
int n, m;

struct Trie
{
    int ch[ maxn ][ 26 ];
    int val[ maxn ];
    int tot;

    void Init()
    {
        memset( ch, 0, sizeof( ch ) );
        memset( val, 0, sizeof( val ) );
        tot = 1;
    }

    void Insert( string word, int Val )
    {
        int root = 0;
        for( auto x: word )
        {
            int Index = 0;
            if( x >= 'a' && x <= 'z' ) Index = x - 'a';
            else                       Index = x - 'A';
            if( !ch[ root ][ Index ] ) ch[ root ][ Index ] = tot ++;
            root = ch[ root ][ Index ];
        }
        val[ root ] = Val;
    }

    void Search( int now )
    {
        int root = 0;
        for( int i = now; i < n; ++ i )
        {
            int Index = s[ i ] - 'a';
            if( !ch[ root ][ Index ] ) return;
            else
            {
                root = ch[ root ][ Index ];
                if( val[ root ] )
                {
                    Vis[ i ] = 1;
                    pre[ i ] = now;
                    Mappre[ i ] = val[ root ];
                }
            }
        }
    }

};

Trie T;

int main()
{
    cin >> n;
    cin >> s;
    cin >> m;
    T.Init();
    for( int i = 1; i <= m; ++ i )
    {
        cin >> tmp;
        Map[ i ] = tmp;
        T.Insert( tmp, i );

    }
    reverse( s.begin(), s.end() );
    T.Search( 0 );
    for( int i = 0; i < n; ++ i ) if( Vis[ i ] ) T.Search( i + 1 );
    int num = n - 1;
    while( num != -1 )
    {
        goal.push_back( Map[ Mappre[ num ] ] );
        num = pre[ num ] - 1;
    }
    for( auto x: goal ) cout << x << " ";
    return 0;
}

665E Beautiful Subarrays

题意

问你有多少个区间,异或起来大于等于k

题解

pre[x]表示0~x的前缀和
将题意转换为有多少对( x, y )满足 0<= x < y < n, 使得 pre[ x ] ^pre[y] >= k
Trie存储从高位到地位的01,每次将pre[x]放入后,在Trie中遍历pre[ x + 1 ],遇到k的某个二进制为0时,添加和即可.

代码
#include<bits/stdc++.h>
using namespace std;

const int maxn = 2e7 + 10;

struct Trie
{
    int  ch[ maxn ][ 2 ];
    int  tot_num[ maxn ];
    int  sz;

    void Init()
    {
        memset( ch, 0, sizeof( ch ) );
        memset( tot_num, 0, sizeof( tot_num ) );
        sz = 2;
    }

    void Insert( int num )
    {
        int cnt = 1;
        for( int i = 30; i >= 0; -- i )
        {
            int Index = ( num >> i ) & 1;
            if( !ch[ cnt ][ Index ] )  ch[ cnt ][ Index ] = sz ++;
            tot_num[ cnt ] ++;
            cnt = ch[ cnt ][ Index ];
        }
        tot_num[ cnt ] ++;
    }

    long long Get( int num, int k )
    {
        int cnt = 1;
        long long sum = 0;
        for( int i = 30; i >= 0; -- i )
        {
            int Index_num = ( num >> i ) & 1 ^ 1;
            int Index_k = ( k >> i ) & 1;
            if( Index_k == 0 )
            {
                sum +=tot_num[  ch[ cnt ][ Index_num ] ];
                cnt = ch[ cnt ][ Index_num ^ 1 ];
            }
            else cnt = ch[ cnt ][ Index_num ];
        }
        return sum + tot_num[ cnt ];
    }
};

Trie T;

int main()
{
    int n, k, x;
    cin >> n >> k;
    int pre = 0;
    long long ans = 0;
    T.Init();
    T.Insert( 0 );
    for( int i = 0; i < n; ++ i )
    {
        cin >> x;
        pre ^= x;
        ans += T.Get( pre, k );
        T.Insert( pre );
    }
    cout << ans << endl;
    return 0;
}
    原文作者:Trie树
    原文地址: https://blog.csdn.net/qq_27567837/article/details/83515794
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞