1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
// NTT-friendly prime p = 998244353 = 119 * 2^23 + 1 with primitive root G = 3;
// Gi = 3^{-1} mod p (3 * 332748118 = 998244354 = p + 1).
const ll mod = 998244353;
const ll G = 3;
const ll Gi = 332748118;
// n, m and a[1..n] are the input values read in main.
ll n,m,a[N];
// jc[i] = i! mod p, inv[i] = (i!)^{-1} mod p (both filled in main for i <= m);
// lim / len = current NTT length and its log2; rev = bit-reversal permutation.
ll jc[N],inv[N],lim,len,rev[N];
// Computes a^b mod `mod` by binary (square-and-multiply) exponentiation.
// The base is first normalised into [0, mod), so every product below stays
// under mod^2 (< 2^63) even when a is large or negative.
// Intended for b >= 0; a non-positive exponent returns 1. The original loop
// never terminated for negative b, because >> on a negative value stays negative.
ll ffpow(ll a,ll b)
{
    a %= mod;
    if (a < 0) a += mod;
    ll ans = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
    }
    return ans;
}
// Binomial coefficient "m choose n" modulo mod, i.e. m! / (n! * (m - n)!).
// Note the argument order: the FIRST argument is the lower index.
// Returns 0 when n is negative or exceeds m. Reads the factorial table jc and
// the inverse-factorial table inv, which must be populated up to index m.
ll C(int n,int m)
{
    const bool outOfRange = n < 0 || m < n;
    if (outOfRange) return 0;
    const ll top = jc[m];
    return top * inv[n] % mod * inv[m - n] % mod;
}
// In-place iterative number-theoretic transform of A[0..lim).
// typ == 1 uses root G (forward transform); any other value uses Gi (inverse
// root). The inverse result is NOT divided by lim here; the caller must scale it.
// Relies on the globals lim and rev[] already describing the current length.
void NTT(ll *A,int typ)
{
    // Permute the input into bit-reversed index order.
    for (int idx = 0; idx < lim; ++idx)
        if (idx < rev[idx]) swap(A[idx], A[rev[idx]]);

    const ll root = (typ == 1) ? G : Gi;
    for (int half = 1; half < lim; half <<= 1) {
        const int width = half << 1;
        // Principal width-th root of unity (or its inverse) modulo mod.
        const ll step = ffpow(root, (mod - 1) / width);
        for (int base = 0; base < lim; base += width) {
            ll w = 1;
            for (int k = 0; k < half; ++k) {
                ll &lo = A[base + k];
                ll &hi = A[base + k + half];
                const ll u = lo;
                const ll v = w * hi % mod;
                lo = (u + v) % mod;
                hi = (u - v + mod) % mod;
                w = w * step % mod;
            }
        }
    }
}
// tmp[i]: coefficient list of the i-th factor polynomial (filled in main).
vector <ll> tmp[N];
// f, g: scratch buffers for the NTT-based convolution in CDQNTT.
ll f[N],g[N];

// Returns the product tmp[l] * tmp[l+1] * ... * tmp[r] (coefficients mod p),
// computed by divide and conquer: the products of both halves are multiplied
// with an NTT convolution. An empty range (l > r) yields the constant
// polynomial 1 instead of recursing forever.
// Overwrites the globals lim, len, rev, f and g.
vector <ll> CDQNTT(int l,int r)
{
    if (l > r) return vector <ll>{1};
    if (l == r) return tmp[l];

    const int mid = (l + r) >> 1;
    vector <ll> L = CDQNTT(l, mid);
    vector <ll> R = CDQNTT(mid + 1, r);
    const int l1 = L.size(), l2 = R.size();

    // Smallest power of two able to hold the l1 + l2 - 1 product coefficients.
    lim = 1, len = 0;
    while (lim < l1 + l2) lim <<= 1, ++len;
    for (int i = 0; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));

    // Zero-padded copies of both operands.
    for (int i = 0; i < lim; ++i) f[i] = i < l1 ? L[i] : 0;
    for (int i = 0; i < lim; ++i) g[i] = i < l2 ? R[i] : 0;

    NTT(f, 1); NTT(g, 1);
    for (int i = 0; i < lim; ++i) f[i] = f[i] * g[i] % mod;
    NTT(f, -1);

    // Apply the 1/lim normalisation only to the coefficients that are kept.
    // f[lim] lies outside the transform (and outside f itself when lim == N),
    // so it must never be touched.
    const ll inv = ffpow(lim, mod - 2);
    const int outLen = l1 + l2 - 1;
    L.resize(outLen);
    for (int i = 0; i < outLen; ++i) L[i] = f[i] * inv % mod;
    return L;
}
signed main() { read(n,m); jc[0] = 1; for (int i = 1;i <= n;++i) read(a[i]); for (int i = 1;i <= m;++i) jc[i] = jc[i-1] * i % mod; inv[m] = ffpow(jc[m],mod-2); for (int i = m-1;i >= 0;--i) inv[i] = (inv[i+1] * (i + 1)) %mod; for (int i = 1;i <= n;++i) { for (int k = 0;k <= a[i];++k) { tmp[i].push_back(C(k,a[i])* inv[a[i]-k] % mod); } } vector <ll> res = CDQNTT(1,n); ll ans = 0; for (int i = 0;i <= m;++i) { if (i & 1) ans = (ans + res[i] * (mod - 1) % mod * jc[m - i] % mod) % mod; else ans = (ans + res[i] * jc[m-i] % mod) % mod; } printf("%lld\n",ans % mod); return 0; }
|