关于常系数线性递推...
_THIS_IS_START_OF_ARTICLE_
_THIS_IS_END_OF_ARTICLE_
----------------------------------------------------------------------------------
最近研究了一下这个问题...
最后实现的时候花了整整两天(当然其中大部分时间都在浪...)
主要还是不会的太多了..=_=感觉学到了好多新知识...
然后在学习关于多项式求逆的那一套理论的时候,发现Picks博客上就有这个问题的解法而且貌似比我的简单得多..=_=
不过我和他的方法貌似不一样(窝并看不懂他在说什么...)
所以也算是自己的一个发现吧
另外期间大部分我不能解决的问题都是问的Skydec...
首先是问题:
已知f(n)的递推式f(n)=∑k−1i=0ai∗f(n−k+i)(给定k和向量a)
并且给出f(0)~f(k−1)
现在给定一个n,求f(n)
----------------------------------------------------------------------------------
O(lg(n)∗k3))做法:
可以使用矩阵乘法
首先可以构造出初始向量a=(f(0),f(1)...f(k−1))T
再构造出转移矩阵A:
(010⋯0001⋯0⋮⋮⋮⋱⋮000⋯1a0a1a2⋯ak−1)
然后An∗a的第1行第1列就是答案
然后就可以快速幂辣
----------------------------------------------------------------------------------
O(lg(n)∗k2))做法:
首先得到了f(n)=∑k−1i=0ai∗f(n−20∗k+i) ...式1
把式1代入式1:
f(n)=∑k−1i=0ai∗∑k−1j=0ajf(n−21∗k+i+j)
可以暴力求,最后得到的式子共有2∗k项,即:
f(n)=∑2∗k−1i=0a′if(n−21∗k+i)...式1.5
然后把式1代入式1.5的f(n−21∗k+k)~f(n−21∗k+2∗k−1)这几项中,就得到了
f(n)=∑k−1i=0bi∗f(n−21∗k+i) ...式2
这样不断地把自己代入自己,然后把式1代入自己的后k项,就可以的到式p:(我把f(n−2p+1∗k+i)叫做第i项)
f(n)=∑k−1i=0bi∗f(n−2p∗k+i) ...式p
这样就是一个倍增的过程,每次可以把k的系数乘以2
然后n是奇数的时候,需要往后推一项,把式1带进去即可
----------------------------------------------------------------------------------
O(lg(n)∗lg(k)∗k))做法:
重头戏来辣=w=
考虑上面那个方法的优化
主要分为两步:自己代入自己和把后k项展开到前k项
首先看自己代入自己这个过程:
f(n)=∑k−1i=0ai∗∑k−1j=0ajf(n−2p+1∗k+i+j)
聪明的读者已经发现,这就是个卷积,所以做一个多项式乘法就好了...
现在假设已经求到了f(n)=∑k−1i=0ci∗f(n−21∗k+i)
另外为了方便表示,不放把a翻转一下然后移动一下位置(设这样得到的向量是a′):
f(n)=∑ki=1a′i∗f(n−i)
然后来考虑第二步,这一步中,显然k~2∗k−1这些项会对0~k−1这些项产生贡献
聪明的读者不难发现,第s项(k≤s<2∗k)对第t项(0≤t<k)产生的贡献是:
g(s,t)=cs∗∑ki=1∑b1,b2...bi且∑ij=1bi=s−t且bj≥1且bi≥k−t∏ki=1a′bi
而第t项最后得到的贡献就是dt=∑k≤s<2∗kg(s,t)
听起来很麻烦?
先假设不考虑bi项
不放把t加上bi
因此有:
c′t=cs∗∑ki=1∑b1,b2...bi且∑ij=1bi=s−t且bj≥1∏ki=1a′bi
其中k≤t<2∗k
这就简单多了
假设A是a′的生成函数
设A[k]表示A第k次项系数
那么不难发现上式就等于
c′t=cs∗(∑ki=0Ai)[s−t]
=cs∗((∑+∞i=0Ai) mod xk)[s−t]
=cs∗(1/(1−A)) mod xk)[s−t]
这个显然是一个卷积,于是一个多项式求逆一个多项式乘法就可以求到了
现在求到了c′,考虑将他的各项减去bi(bi≥k−t)来得到最后d
因此:
dt=∑k≤s<2∗kc′s∗a′s−t
然后这也显然是个卷积,多项式乘法即可
至此问题就解决了,最后c′i=ci+di(0≤i<k)就是展开完之后的f(n)的表达式的前k项
只用到了多项式乘法和多项式求逆,这两个都是O(lg(k)∗k)的
然后算上倍增的时间,最后的时间复杂度O(lg(n)∗lg(k)∗k))
最后,实现的时候,多项式求逆中会爆精度,所以只能用NTT做(或者用分块乘法...)...
----------------------------------------------------------------------------------
这里有一个很挫的实现:(只实现了最后一步,也就是把f(n)含有2∗k项的表达式展开成只含有k项的表达式,这是最难也是最关键的一步)
(输入输出格式是:第一行输入k,然后输入a′,然后输入有2∗n项的c,输出展开之后只有n项的c′)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 | #include <stdio.h> #include <string.h> #include <math.h> #include <time.h> typedef long long Long; typedef double Ld; const int N = 800000*2; const Ld PI = acos (-1); const int P = 998244353; const int W = 3; int pow ( int a, int b) { a %= P; if (a < 0) a += P; int r = 1.; while (b) { if (b & 1) r = (1LL * r * a) % (1LL * P); a = (1LL * a * a) % (1LL * P); b >>= 1; } return r; } namespace NTT { int wn[N + 10]; int rev[N + 10]; void init( int n) { int log2n = 0; int nn = (n >> 1); while (nn) { log2n ++; nn >>= 1; } int num = 0; for ( int i = 1;i <= n;i <<= 1) { wn[num++] = pow (W,(P-1)/i); } int x = 0; int y = 0; for ( int i = 0;i < n;i++) { x = i; y = 0; for ( int j = 1;j <= log2n;j++) { y <<= 1; y |= (x & 1); x >>= 1; } rev[i] = y; } } int buf[N + 10]; void NTT( int * a, int n, int s) { for ( int i = 0;i < n;i++) buf[i] = a[rev[i]]; for ( int i = 0;i < n;i++) a[i] = buf[i]; int t = 2; int div2 = 0; int l = 0; int num = 0; while (t <= n) { div2 = (t >> 1); l = n / t; int ww = wn[++num]; if (s) ww = pow (ww,P-2); for ( int i = 0;i < n;i += t) { int w = 1; for ( int j = 0;j < div2;j++) { int x = a[i+j]; int y = (1LL * a[i+j+div2] * w) % (1LL * P); a[i+j] = (x + y) % P; a[i+j+div2] = (x + P - y) % P; w = (1LL * w * ww) % (1LL * P); } } t <<= 1; } if (s) { int di = pow (n,P-2); for ( int i = 0;i < n;i++) a[i] = (1LL * a[i] * di) % (1LL * P); } } void mul( int * a, int * b, int * c, int n) { init(2 * n); NTT(a,2 * n,0); NTT(b,2 * n,0); for ( int i = 0;i < 2 * n;i++) c[i] = (1LL * a[i] * b[i]) % (1LL * P); NTT(c,2 * n,1); } int tmp[N + 10]; void get_inv( int * a, int * b, int t) //b = a^-1 mod x^t { if (t == 1) { b[0] = pow (a[0],P-2); return ; } get_inv(a,b,(t + 1) >> 1); int k = 0; for (k = 1;k <= (t << 1) + 3;k <<= 1); for ( int i = 0;i < t;i++) tmp[i] = a[i]; for ( int i = t;i < k;i++) tmp[i] = 0; init(k); NTT(tmp,k,0); NTT(b,k,0); for ( int i = 0;i < k;i++) { int val = (1LL * b[i] * b[i]) % (1LL * P); tmp[i] = (1LL * tmp[i] * val) % (1LL * P); } NTT(tmp,k,1); NTT(b,k,1); for ( int i = 0;i < k;i++) b[i] = (2LL * b[i] + P - tmp[i]) % P; for ( int i = t;i < k;i++) b[i] = 0; } }; int pa[N + 10]; int a[N + 10]; int d[N + 10]; //d[i] = c[i+n] int c[N + 10]; int b[N + 10]; int p[N + 10]; int main() { FILE * fin = fopen ( "test.in" , "r" ); FILE * fout = fopen ( "test.out" , "w" ); int n = 0; fscanf (fin, "%d" ,&n); for ( int i = 1;i <= n;i++) { fscanf (fin, "%d" ,a + i); a[i] = ((a[i] % P) + P) % P; } for ( int i = 0;i < 2 * n;i++) { fscanf (fin, "%d" ,c + i); c[i] = ((c[i] % P) + P) % P; if (i >= n) d[i] = c[i]; } { for ( int i = 1;i <= n;i++) pa[i] = -a[i]; pa[0] ++; NTT::get_inv(pa,b,n); } { for ( int i = 0;i <= (n>>1);i++) { if (i < n-i) b[i] ^= b[n-i] ^= b[i] ^= b[n-i]; } int e = 3 * n; int q = 1; while (q <= e) q <<= 1; NTT::mul(d,b,p,q); for ( int i = 3 * n - 1;i >= 2 * n;i--) { p[i-n] = p[i]; p[i] = 0; } } { memset (b,0, sizeof (b)); for ( int i = 0;i <= n;i++) { if (i < 2*n-i) p[i] ^= p[2*n-i] ^= p[i] ^= p[2*n-i]; } int e = 6 * n; int q = 1; while (q <= e) q <<= 1; NTT::mul(a,p,b,q); for ( int i = 0;i < 2 * n;i++) c[i] += b[2*n-i]; } for ( int i = 0;i < n;i++) { if (i == n-1) fprintf (fout, "%d\n" ,((c[i] % P) + P) % P); else fprintf (fout, "%d " ,((c[i] % P) + P) % P); } fclose (fin); fclose (fout); return 0; } |
----------------------------------------------------------------------------------
2016年4月26日 21:40
大爷好强%%%