# Codeforces Round 877 (Div. 2) (E)
> https://codeforces.com/contest/1838/problem/E
給一個長度$n$的數列$a$,其中$1 \leq a_i \leq k$,問說有多少個長度為$m$的數列$b$,其中$1 \leq b_i \leq k$,使得$b$的子序列有數列$a$。$1 \leq n \leq 2 \times 10^5$,$n \leq m \leq 10^9$,$1 \leq k \leq 10^9$。
看到$m$的範圍想先土法煉鋼構造看看,假如我們枚舉$a_n$放在$b$中的位置,則我們可以寫出所求應為:
\begin{array}{ll}
\sum_{x=n}^{m}C^{x-1}_{n-1} \times (k-1)^{(x-1)-(n-1)} \times k^{m-x},
\end{array}
上式表示將$a$依序放入$b$且$a_n$之位置為$x$,在放入$a_{i-1}$後到放入$a_i$的位置之前不能放入$a_i$,所以在$a_n$之前每個空格都可放$k-1$個數,所以總共有$(k-1)^{(x-1)-(n-1)}$這麼多種可能,再來是放了$a_n$之後可以隨便放,所以有$k^{m-x}$這麼多種可能,然後然後$m$超大我就不會了==。
筆記一下解答的解法定義$dp(i,j)$為考慮長度為$i$的序列$b$且有最長a之前綴長度$j$,然後可以寫出:
\begin{array}{ll}
dp(i, j) = dp(i-1, j-1) + (k-1) \times dp(i-1, j),if \ (j \lt n), \\
dp(i, n) = dp(i-1, n-1) + k \times dp(i-1, n),
\end{array}
這DP方程大概在$n,m$很小時很常見==,所求dp(m,n)直接是無法算的(至少我不會==)。
然後有個很妙的觀察是,這轉移方程跟數列$a$長怎樣根本無關,然後令個$a$都是1,又有想法了==,所求應為$m$中取$x$個位置為1,枚舉x得所求應為:
\begin{array}{ll}
\sum_{x = n}^{m} C^{m}_{x} \times (k-1)^{m-x},
\end{array}
想了一下發現從$x=0$時很好算==,然後再扣掉就好:
\begin{array}{ll}
\sum_{x = n}^{m} C^{m}_{x} \times (k-1)^{m-x} \\
= \sum_{x = 0}^{m} C^{m}_{x} \times (k-1)^{m-x} - \sum_{x = 0}^{n-1} C^{m}_{x} \times (k-1)^{m-x}
\\
= (1 + (k-1))^{m} - \sum_{x = 0}^{n-1} C^{m}_{x} \times (k-1)^{m-x},
\end{array}
然後注意一下因為$m$很大,二項式係數要直接算。
```cpp=
#include <bits/stdc++.h>
#pragma GCC optimize(3)
#define ll long long
#define pii pair<int, int>
#define pll pair<long long, long long>
#define F first
#define S second
#define endl '\n'
using namespace std;
const int inf = 0x3f3f3f3f;
const ll Inf = 1e18;
const int mod = 1e9 + 7;
const int N = 2e5 + 5;
ll qmul(ll x, ll y, int p = mod){
ll cur = x, ans = 1;
while(y > 0){
if(y & 1) ans = (ans * cur) % p;
cur = (cur * cur) % p;
y >>= 1;
}
return ans;
}
const int maxn = 1e6 + 5;
ll fac[maxn], inv[maxn];
void init(){
fac[0] = fac[1] = inv[0] = inv[1] = 1;
for(int i=2; i<maxn; i++){
fac[i] = (i * fac[i-1]) % mod;
inv[i] = mod - mod/i * inv[mod%i] % mod;
}
// for(int i=2; i<maxn; i++){
// inv[i] = (inv[i] * inv[i-1]) % mod;
// }
return;
}
int n, m, k;
ll a[N];
void solve(){
cin >> n >> m >> k;
for(int i=1; i<=n; i++) cin >> a[i];
if(m < n){
cout << 0 << endl;
return;
}
if(k == 1 || m == n){
cout << 1 << endl;
return;
}
ll ans = qmul(k, m);
ll tmp = qmul(k-1, m);
ans = (ans - tmp) % mod;
ll invk_1 = qmul(k-1, mod - 2);
for(int i=1; i<=n-1; i++){
tmp = (tmp * (m - i + 1) % mod * inv[i] % mod * invk_1) % mod;
ans = (ans - tmp) % mod;
}
if(ans < 0) ans += mod;
cout << ans << endl;
}
int main(){
ios::sync_with_stdio(0); cin.tie(0);
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
init();
int t = 1;
cin >> t;
while(t--)
solve();
return 0;
}
```
### 時間複雜度 :$O(nlog(mod))$
###### tags: `combinatorics` `math`