题目大意:
就是现在给出至多10^4个字符串每个长度都在1~40之间, 只包含小写字母, 问如果将其中任意一个串的前缀或者是任意一个串的后缀连接起来可以构成一个新词, 那么包括这些词本身在内一共可以形成多少个不同的词
大致思路:
这个题感觉还是挺巧妙地利用了Trie树来计数, 首先我们将所有的n个串插入到一个Trie树中, 然后将所有串倒过来插入到另外一个Trie书中, 那么trie1中的节点数 – 1就是非空的不同的前缀的数量, trie2 中节点数-1就是不同的后缀数量, 设节点数分别为L1, L2, 那么L1*L2就是形式为 S1+S2的字符串数量, S1是前缀, S2是后缀, 都不为空
那么现在考虑重复的情况, 这里我们的S1和S2不空所以没有计算单个字符的单词的情况, 对于单个字符在输入的时候判断一下就好了
那么处理S1+S2的个数重复问题, 首先如果S1+S2长度是2的话是没有计算重复的, 否则的话一个形似S3 + c + S4的串, 在S1 = S3, S2 = c + S4和S1 = S3 + c, S2 = S4时被计算重复
那么就是说每一对在不同Trie树中的相同字符会导致一次重复计数(这个重复只能是相同字符出现在Trie树中第2层以下, 根是第0层), 于是减去重复的部分就行了
代码如下:
Result : Accepted Memory : ? KB Time : 222 ms
/*
* Author: Gatevin
* Created Time: 2015/5/6 12:01:55
* File Name: Rin_Tohsaka.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
#define foreach(e, x) for(__typeof(x.begin()) e = x.begin(); e != x.end(); ++e)
#define SHOW_MEMORY(x) cout<<sizeof(x)/(1024*1024.)<<"MB"<<endl
struct Trie
{
int L, root;
int next[400001][26];
int cnt[26];
int newnode()
{
for(int i = 0; i < 26; i++)
next[L][i] = -1;
return L++;
}
void init()
{
L = 0;
root = newnode();
memset(cnt, 0, sizeof(cnt));
}
void insert(char *s, int len)
{
int now = root;
for(int i = 0; i < len; i++)
{
if(next[now][s[i] - 'a'] == -1)
{
next[now][s[i] - 'a'] = newnode();
if(i) cnt[s[i] - 'a']++;//第2层和以上
}
now = next[now][s[i] - 'a'];
}
}
};
Trie trie1, trie2;
char s[50];
bool vis[26];
int main()
{
int n;
while(scanf("%d", &n) != EOF)
{
trie1.init(), trie2.init();
memset(vis, 0, sizeof(vis));
lint ans = 0;
while(n--)
{
scanf("%s", s);
int len = strlen(s);
if(len == 1 && !vis[s[0] - 'a'])//单个字符的
vis[s[0] - 'a'] = 1, ans++;
trie1.insert(s, len);
reverse(s, s + len);
trie2.insert(s, len);
}
ans += (lint)(trie1.L - 1)*(lint)(trie2.L - 1);//S1+S2形式的串(有重复)
for(int i = 0; i < 26; i++)
ans -= (lint)trie1.cnt[i]*(lint)trie2.cnt[i];//去重
printf("%lld\n", ans);
}
return 0;
}