Do ta cần lấy $lcm$ của dãy $x$ và kiểm tra nó có chia hết cho $k$ hay không, với $k = \prod p_i^{x_i}$, ta chỉ cần quan tâm xem dãy $x$ có tồn tại phần tử nào chia hết cho các $p_i^{x_i}$ hay không mà thôi.
Ở đây, số lượng số nguyên tố trong phân tích thừa số nguyên tố của $k$ sẽ bằng $m$ và $1 \le m \le 14$, do tích của $15$ số nguyên tố đầu tiên sẽ vượt quá $k$.
Do đó ta nghĩ tới việc đặt $cnt(mask)$ là số lượng số trong đoạn từ $a$ tới $b$ chia hết cho tích các số trong $mask$, với $mask$ ở đây là một $bitmask$ - $bit$ $i$ tương đương $p_i^{x_i}$.
Tuy nhiên như vậy, sẽ có khả năng tập $cnt(mask)$ đếm trùng các số chia hết cho cả các $mask'$ khác, để tránh được điều này, ta chỉ cần dùng **bao hàm loại trừ**.
Lưu ý ở đây ta cần bao hàm loại trừ cho tất cả các $mask$, tức mất $3^m$. Ở code mẫu mình dùng $DP$ $SOS$ nên chỉ mất $2^m \times m$.
Sau đó ta có hàm $f(mask)$ là số lượng số trong đoạn $[a,b]$ chia hết cho một **tập con** của $mask$. Bài toán này được gọi là "Sum over subset", độ phức tạp trâu nhất là $3^m$, nhưng ở đây mình dùng $DP$ $SOS$ để tối ưu. Các bạn có thể không cần dùng tới $DP$ $SOS$ vẫn có thể $AC$. Hàm $f(mask)$ đơn giản được tính bằng tổng các $dp(mask')$ với $mask'$ là tập con của $mask$.
Tới đây, ta có $g(mask) = f(mask)^n$, tức là số dãy $x$ có $lcm$ là tập con của $mask$.
Để biến $g(mask)$ thành số dãy $x$ có $lcm$ bằng $mask$, ta cần dùng **bao hàm loại trừ** thêm một lần nữa.
Code: (DP SOS)
```cpp=
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ii pair<int,int>
const int N = 230;
const int M = 12;
const int mod = 998244353;
int f[1 << M];
int dp[1 << M];
int n,k,a,b,m;
vector<int> p;
int inv (int a) {
//return -a;
return (mod - a + mod)%mod;
}
void add (int &a, int b) {
a += b;
if (a >= mod) a -= mod;
}
void make (int k) {
for (int i=2; i*i<=k; i++) {
int cur = 1;
while (k % i == 0) {
cur*=i;
k /= i;
}
if (cur != 1) p.push_back(cur);
}
if (k != 1) p.push_back(k);
}
void sos () {
for (int mask=0; mask<(1 << m); mask++) {
int cur = 1;
for (int i=0; i<m; i++) {
if (mask >> i & 1) cur *= p[i];
}
dp[mask] = (b/cur - (a-1)/cur) % mod;
// cout << mask << ' ' << dp[mask] << '\n';
if (__builtin_popcount(mask) % 2 == 0) dp[mask] = inv(dp[mask]);
// cout << mask << ' ' << dp[mask] << '\n';
}
int mx = (1 << m);
for (int i=0; i<m; i++) {
for (int mask=0; mask<mx; mask++) {
if (mask >> i & 1) {
add (dp[mask ^ (1 << i)], dp[mask]);
}
}
}
for (int mask=0; mask<mx; mask++) {
if (__builtin_popcount(mask) % 2 == 0) dp[mask] = inv(dp[mask]);
}
}
signed main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cin >> n >> k >> a >> b;
make(k);
m = p.size();
sos();
int mx = 1 << m;
for (int i=0; i<m; i++) {
for (int mask=0; mask<mx; mask++) {
if (mask >> i & 1) add (dp[mask],dp[mask^(1 << i)]);
}
}
for (int mask=0; mask<mx; mask++) {
f[mask] = 1;
}
for (int i=1; i<=n; i++){
for (int mask=0; mask < mx; mask++) {
f[mask] = f[mask]*dp[mask]%mod;
}
}
int mask = mx - 1;
int res = 0;
int dak = m % 2;
for (int i = mask; true ; i = (i-1)&mask) {
if (__builtin_popcount(i) % 2 == dak) add (res,f[i]);
else {
res = (res - f[i] + mod)%mod;
}
if (i == 0) break;
}
cout << res;
}
```
Code: $3^m$:
```cpp=
#include <iostream>
#include <fstream>
#include <vector>
using namespace std;
const int mod = 998244353;
const int N = 1e7 + 5;
vector <long long> p;
int f[1 << 15];
int dp[1 << 15];
long long cnt[1 << 15];
int mask[N];
int bit(int x) {
return (1 << x);
}
int getbit(int x, int i) {
return (x >> i) & 1;
}
void add(int &a, int b) {
a += b;
if (a >= mod)
a -= mod;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
long long n, k, a, b;
cin >> n >> k >> a >> b;
for (int i = 2; (long long) i * i <= k; i++) {
if (k % i == 0) {
long long x = 1;
while (k % i == 0) {
k /= i;
x *= i;
}
p.push_back(x);
}
}
if (k > 1)
p.push_back(k);
int sz = p.size();
for (int mask = 0; mask < bit(sz); mask++) {
long long x = 1;
for (int j = 0; j < sz; j++)
if (getbit(mask, j))
x *= p[j];
cnt[mask] = (b / x) - ((a - 1) / x);
}
for (int mask = bit(sz) - 1; mask > 0; mask--) {
for (int s = (mask - 1) & mask;; s = (s - 1) & mask) {
cnt[s] -= cnt[mask];
if (s == 0)
break;
}
}
for (int mask = 0; mask < bit(sz); mask++)
cnt[mask] %= mod;
for (int mask = 0; mask < bit(sz); mask++) {
add(dp[mask], cnt[mask]);
if (mask == 0)
continue;
for (int s = (mask - 1) & mask;; s = (s - 1) & mask) {
add(dp[mask], cnt[s]);
if (s == 0)
break;
}
}
for (int mask = 0; mask < bit(sz); mask++) {
f[mask] = 1;
for (int i = 1; i <= n; i++)
f[mask] = (long long) f[mask] * dp[mask] % mod;
}
for (int mask = 1; mask < bit(sz); mask++)
for (int s = (mask - 1) & mask;; s = (s - 1) & mask) {
add(f[mask], mod - f[s]);
if (s == 0)
break;
}
cout << f[bit(sz) - 1];
return 0;
}
```