Do ta cần lấy $lcm$ của dãy $x$ và kiểm tra nó có chia hết cho $k$ hay không, với $k = \prod p_i^{x_i}$, ta chỉ cần quan tâm xem dãy $x$ có tồn tại phần tử nào chia hết cho các $p_i^{x_i}$ hay không mà thôi. Ở đây, số lượng số nguyên tố trong phân tích thừa số nguyên tố của $k$ sẽ bằng $m$ và $1 \le m \le 14$, do tích của $15$ số nguyên tố đầu tiên sẽ vượt quá $k$. Do đó ta nghĩ tới việc đặt $cnt(mask)$ là số lượng số trong đoạn từ $a$ tới $b$ chia hết cho tích các số trong $mask$, với $mask$ ở đây là một $bitmask$ - $bit$ $i$ tương đương $p_i^{x_i}$. Tuy nhiên như vậy, sẽ có khả năng tập $cnt(mask)$ đếm trùng các số chia hết cho cả các $mask'$ khác, để tránh được điều này, ta chỉ cần dùng **bao hàm loại trừ**. Lưu ý ở đây ta cần bao hàm loại trừ cho tất cả các $mask$, tức mất $3^m$. Ở code mẫu mình dùng $DP$ $SOS$ nên chỉ mất $2^m \times m$. Sau đó ta có hàm $f(mask)$ là số lượng số trong đoạn $[a,b]$ chia hết cho một **tập con** của $mask$. Bài toán này được gọi là "Sum over subset", độ phức tạp trâu nhất là $3^m$, nhưng ở đây mình dùng $DP$ $SOS$ để tối ưu. Các bạn có thể không cần dùng tới $DP$ $SOS$ vẫn có thể $AC$. Hàm $f(mask)$ đơn giản được tính bằng tổng các $dp(mask')$ với $mask'$ là tập con của $mask$. Tới đây, ta có $g(mask) = f(mask)^n$, tức là số dãy $x$ có $lcm$ là tập con của $mask$. Để biến $g(mask)$ thành số dãy $x$ có $lcm$ bằng $mask$, ta cần dùng **bao hàm loại trừ** thêm một lần nữa. Code: (DP SOS) ```cpp= #include<bits/stdc++.h> using namespace std; #define int long long #define ii pair<int,int> const int N = 230; const int M = 12; const int mod = 998244353; int f[1 << M]; int dp[1 << M]; int n,k,a,b,m; vector<int> p; int inv (int a) { //return -a; return (mod - a + mod)%mod; } void add (int &a, int b) { a += b; if (a >= mod) a -= mod; } void make (int k) { for (int i=2; i*i<=k; i++) { int cur = 1; while (k % i == 0) { cur*=i; k /= i; } if (cur != 1) p.push_back(cur); } if (k != 1) p.push_back(k); } void sos () { for (int mask=0; mask<(1 << m); mask++) { int cur = 1; for (int i=0; i<m; i++) { if (mask >> i & 1) cur *= p[i]; } dp[mask] = (b/cur - (a-1)/cur) % mod; // cout << mask << ' ' << dp[mask] << '\n'; if (__builtin_popcount(mask) % 2 == 0) dp[mask] = inv(dp[mask]); // cout << mask << ' ' << dp[mask] << '\n'; } int mx = (1 << m); for (int i=0; i<m; i++) { for (int mask=0; mask<mx; mask++) { if (mask >> i & 1) { add (dp[mask ^ (1 << i)], dp[mask]); } } } for (int mask=0; mask<mx; mask++) { if (__builtin_popcount(mask) % 2 == 0) dp[mask] = inv(dp[mask]); } } signed main() { ios_base::sync_with_stdio(false); cin.tie(0); cin >> n >> k >> a >> b; make(k); m = p.size(); sos(); int mx = 1 << m; for (int i=0; i<m; i++) { for (int mask=0; mask<mx; mask++) { if (mask >> i & 1) add (dp[mask],dp[mask^(1 << i)]); } } for (int mask=0; mask<mx; mask++) { f[mask] = 1; } for (int i=1; i<=n; i++){ for (int mask=0; mask < mx; mask++) { f[mask] = f[mask]*dp[mask]%mod; } } int mask = mx - 1; int res = 0; int dak = m % 2; for (int i = mask; true ; i = (i-1)&mask) { if (__builtin_popcount(i) % 2 == dak) add (res,f[i]); else { res = (res - f[i] + mod)%mod; } if (i == 0) break; } cout << res; } ``` Code: $3^m$: ```cpp= #include <iostream> #include <fstream> #include <vector> using namespace std; const int mod = 998244353; const int N = 1e7 + 5; vector <long long> p; int f[1 << 15]; int dp[1 << 15]; long long cnt[1 << 15]; int mask[N]; int bit(int x) { return (1 << x); } int getbit(int x, int i) { return (x >> i) & 1; } void add(int &a, int b) { a += b; if (a >= mod) a -= mod; } int main() { ios_base::sync_with_stdio(0); cin.tie(0); long long n, k, a, b; cin >> n >> k >> a >> b; for (int i = 2; (long long) i * i <= k; i++) { if (k % i == 0) { long long x = 1; while (k % i == 0) { k /= i; x *= i; } p.push_back(x); } } if (k > 1) p.push_back(k); int sz = p.size(); for (int mask = 0; mask < bit(sz); mask++) { long long x = 1; for (int j = 0; j < sz; j++) if (getbit(mask, j)) x *= p[j]; cnt[mask] = (b / x) - ((a - 1) / x); } for (int mask = bit(sz) - 1; mask > 0; mask--) { for (int s = (mask - 1) & mask;; s = (s - 1) & mask) { cnt[s] -= cnt[mask]; if (s == 0) break; } } for (int mask = 0; mask < bit(sz); mask++) cnt[mask] %= mod; for (int mask = 0; mask < bit(sz); mask++) { add(dp[mask], cnt[mask]); if (mask == 0) continue; for (int s = (mask - 1) & mask;; s = (s - 1) & mask) { add(dp[mask], cnt[s]); if (s == 0) break; } } for (int mask = 0; mask < bit(sz); mask++) { f[mask] = 1; for (int i = 1; i <= n; i++) f[mask] = (long long) f[mask] * dp[mask] % mod; } for (int mask = 1; mask < bit(sz); mask++) for (int s = (mask - 1) & mask;; s = (s - 1) & mask) { add(f[mask], mod - f[s]); if (s == 0) break; } cout << f[bit(sz) - 1]; return 0; } ```