题目大意
给定n个字符串,从这n个字符串中选择k个字符串组成一个新的字符串,然后给定一个新的字符串,然后输出该字符串的字典序在所有的新字符串的位置。
思路
可以先拿n个字符串排序,然后再建树,树结尾的那个字符的值设为排序后的编号,然后对新字符串查找,找到这几个的序号,到此分离字符串的工作完成了,然后就是计算排列了。树状数组维护的就是字典序小于某个字符串的数量。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mo = (int)1e9 + 7,
maxn = (int)1e6 + 10,
sigma_size = 27;
string str;
vector<int> vec;
vector<string> vs;
ll A[maxn];
struct Trie {
int ch[maxn][sigma_size];
int val[maxn];
int sz;
Trie() { sz = 1; memset(ch[0], 0, sizeof(ch[0])); }
int idx(char c) { return c - 'a'; }
void insert(string s, int v) {
int u = 0, n = s.length();
for(int i = 0; i < n; i++) {
int c = idx(s[i]);
if(!ch[u][c]) {
memset(ch[sz], 0, sizeof ch[sz]);
val[sz] = 0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u] = v;
}
void find(string s) {
int u = 0, len = s.length();
for (int i = 0; i < len; i++) {
int c = idx(s[i]);
u = ch[u][c];
if (val[u]) {
vec.push_back(val[u]);
u = 0;
}
}
}
} trie;
struct BIT {
ll C[maxn];
BIT() { memset(C, 0, sizeof C); }
int lowbit(int x) { return x&-x; }
ll sum(int x) {
ll ret = 0;
while(x > 0) {
ret = (ret + C[x]) % mo;
x -= lowbit(x);
}
return ret;
}
void add(int x, int d) {
while(x <= maxn) {
C[x] += d; x += lowbit(x);
}
}
} bit;
int main() {
//freopen("i.txt", "r", stdin);
ios::sync_with_stdio(false);
cin.tie(0);
int n, k;
cin >> n >> k;
for (int i = 0; i < n; i++) {
cin >> str;
vs.push_back(str);
}
sort(vs.begin(), vs.end());
for (int i = 0; i < n; i++) {
trie.insert(vs[i], i + 1);
bit.add(i + 1, 1);
}
cin >> str;
vec.clear();
trie.find(str);
ll ret = 0;
A[n-k] = 1;
for(int i = n-k+1; i <= n - 1; i++) {
A[i] = (A[i-1]*i) % mo; // 求排列数
}
for (int i = 0; i < k; i++) {
ret = (ret + bit.sum(vec[i] - 1) * A[n - i - 1]) % mo;
bit.add(vec[i], -1);
}
cout << (ret + 1) % mo << endl;
return 0;
}