# Codeforces Round 877 (Div. 2) (E) > https://codeforces.com/contest/1838/problem/E 給一個長度$n$的數列$a$,其中$1 \leq a_i \leq k$,問說有多少個長度為$m$的數列$b$,其中$1 \leq b_i \leq k$,使得$b$的子序列有數列$a$。$1 \leq n \leq 2 \times 10^5$,$n \leq m \leq 10^9$,$1 \leq k \leq 10^9$。 看到$m$的範圍想先土法煉鋼構造看看,假如我們枚舉$a_n$放在$b$中的位置,則我們可以寫出所求應為: \begin{array}{ll} \sum_{x=n}^{m}C^{x-1}_{n-1} \times (k-1)^{(x-1)-(n-1)} \times k^{m-x}, \end{array} 上式表示將$a$依序放入$b$且$a_n$之位置為$x$,在放入$a_{i-1}$後到放入$a_i$的位置之前不能放入$a_i$,所以在$a_n$之前每個空格都可放$k-1$個數,所以總共有$(k-1)^{(x-1)-(n-1)}$這麼多種可能,再來是放了$a_n$之後可以隨便放,所以有$k^{m-x}$這麼多種可能,然後然後$m$超大我就不會了==。 筆記一下解答的解法定義$dp(i,j)$為考慮長度為$i$的序列$b$且有最長a之前綴長度$j$,然後可以寫出: \begin{array}{ll} dp(i, j) = dp(i-1, j-1) + (k-1) \times dp(i-1, j),if \ (j \lt n), \\ dp(i, n) = dp(i-1, n-1) + k \times dp(i-1, n), \end{array} 這DP方程大概在$n,m$很小時很常見==,所求dp(m,n)直接是無法算的(至少我不會==)。 然後有個很妙的觀察是,這轉移方程跟數列$a$長怎樣根本無關,然後令個$a$都是1,又有想法了==,所求應為$m$中取$x$個位置為1,枚舉x得所求應為: \begin{array}{ll} \sum_{x = n}^{m} C^{m}_{x} \times (k-1)^{m-x}, \end{array} 想了一下發現從$x=0$時很好算==,然後再扣掉就好: \begin{array}{ll} \sum_{x = n}^{m} C^{m}_{x} \times (k-1)^{m-x} \\ = \sum_{x = 0}^{m} C^{m}_{x} \times (k-1)^{m-x} - \sum_{x = 0}^{n-1} C^{m}_{x} \times (k-1)^{m-x} \\ = (1 + (k-1))^{m} - \sum_{x = 0}^{n-1} C^{m}_{x} \times (k-1)^{m-x}, \end{array} 然後注意一下因為$m$很大,二項式係數要直接算。 ```cpp= #include <bits/stdc++.h> #pragma GCC optimize(3) #define ll long long #define pii pair<int, int> #define pll pair<long long, long long> #define F first #define S second #define endl '\n' using namespace std; const int inf = 0x3f3f3f3f; const ll Inf = 1e18; const int mod = 1e9 + 7; const int N = 2e5 + 5; ll qmul(ll x, ll y, int p = mod){ ll cur = x, ans = 1; while(y > 0){ if(y & 1) ans = (ans * cur) % p; cur = (cur * cur) % p; y >>= 1; } return ans; } const int maxn = 1e6 + 5; ll fac[maxn], inv[maxn]; void init(){ fac[0] = fac[1] = inv[0] = inv[1] = 1; for(int i=2; i<maxn; i++){ fac[i] = (i * fac[i-1]) % mod; inv[i] = mod - mod/i * inv[mod%i] % mod; } // for(int i=2; i<maxn; i++){ // inv[i] = (inv[i] * inv[i-1]) % mod; // } return; } int n, m, k; ll a[N]; void solve(){ cin >> n >> m >> k; for(int i=1; i<=n; i++) cin >> a[i]; if(m < n){ cout << 0 << endl; return; } if(k == 1 || m == n){ cout << 1 << endl; return; } ll ans = qmul(k, m); ll tmp = qmul(k-1, m); ans = (ans - tmp) % mod; ll invk_1 = qmul(k-1, mod - 2); for(int i=1; i<=n-1; i++){ tmp = (tmp * (m - i + 1) % mod * inv[i] % mod * invk_1) % mod; ans = (ans - tmp) % mod; } if(ans < 0) ans += mod; cout << ans << endl; } int main(){ ios::sync_with_stdio(0); cin.tie(0); freopen("input.txt", "r", stdin); freopen("output.txt", "w", stdout); init(); int t = 1; cin >> t; while(t--) solve(); return 0; } ``` ### 時間複雜度 :$O(nlog(mod))$ ###### tags: `combinatorics` `math`