关于常系数线性递推...
_THIS_IS_START_OF_ARTICLE_
_THIS_IS_END_OF_ARTICLE_
----------------------------------------------------------------------------------
最近研究了一下这个问题...
最后实现的时候花了整整两天(当然其中大部分时间都在浪...)
主要还是不会的太多了..=_=感觉学到了好多新知识...
然后在学习关于多项式求逆的那一套理论的时候,发现Picks博客上就有这个问题的解法而且貌似比我的简单得多..=_=
不过我和他的方法貌似不一样(窝并看不懂他在说什么...)
所以也算是自己的一个发现吧
另外期间大部分我不能解决的问题都是问的Skydec...
首先是问题:
已知$f(n)$的递推式$f(n)=\sum_{i=0}^{k-1}a_{i}*f(n-k+i)$(给定$k$和向量$a$)
并且给出$f(0)$~$f(k-1)$
现在给定一个$n$,求$f(n)$
----------------------------------------------------------------------------------
$O(lg(n)*k^3))$做法:
可以使用矩阵乘法
首先可以构造出初始向量$a=(f(0),f(1)...f(k-1))^T$
再构造出转移矩阵$A$:
$\begin{pmatrix}0 & 1 & 0 & \cdots & 0 \\0 & 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\0 & 0 & 0 & \cdots & 1 \\a_0 & a_1 & a_2 & \cdots & a_{k-1} \\ \end{pmatrix}$
然后$A^{n}*a$的第$1$行第$1$列就是答案
然后就可以快速幂辣
----------------------------------------------------------------------------------
$O(lg(n)*k^2))$做法:
首先得到了$f(n)=\sum_{i=0}^{k-1}a_{i}*f(n-2^{0}*k+i)$ ...式$1$
把式$1$代入式$1$:
$f(n)=\sum_{i=0}^{k-1}a_{i}*\sum_{j=0}^{k-1}a_{j}f(n-2^{1}*k+i+j)$
可以暴力求,最后得到的式子共有$2*k$项,即:
$f(n)=\sum_{i=0}^{2*k-1}a^{'}_{i}f(n-2^{1}*k+i)$...式$1.5$
然后把式$1$代入式$1.5$的$f(n-2^{1}*k+k)$~$f(n-2^{1}*k+2*k-1)$这几项中,就得到了
$f(n)=\sum_{i=0}^{k-1}b_{i}*f(n-2^{1}*k+i)$ ...式$2$
这样不断地把自己代入自己,然后把式$1$代入自己的后$k$项,就可以的到式$p$:(我把$f(n-2^{p+1}*k+i)$叫做第$i$项)
$f(n)=\sum_{i=0}^{k-1}b_{i}*f(n-2^{p}*k+i)$ ...式$p$
这样就是一个倍增的过程,每次可以把$k$的系数乘以$2$
然后$n$是奇数的时候,需要往后推一项,把式$1$带进去即可
----------------------------------------------------------------------------------
$O(lg(n)*lg(k)*k))$做法:
重头戏来辣=w=
考虑上面那个方法的优化
主要分为两步:自己代入自己和把后$k$项展开到前$k$项
首先看自己代入自己这个过程:
$f(n)=\sum_{i=0}^{k-1}a_{i}*\sum_{j=0}^{k-1}a_{j}f(n-2^{p+1}*k+i+j)$
聪明的读者已经发现,这就是个卷积,所以做一个多项式乘法就好了...
现在假设已经求到了$f(n)=\sum_{i=0}^{k-1}c_{i}*f(n-2^{1}*k+i)$
另外为了方便表示,不放把$a$翻转一下然后移动一下位置(设这样得到的向量是$a^{'}$):
$f(n)=\sum_{i=1}^{k}a^{'}_{i}*f(n-i)$
然后来考虑第二步,这一步中,显然$k$~$2*k-1$这些项会对$0$~$k-1$这些项产生贡献
聪明的读者不难发现,第s项($k \le s \lt 2*k$)对第t项($0 \le t \lt k$)产生的贡献是:
$g(s,t)=c_{s}*\sum_{i=1}^{k}\sum_{b_{1},b_{2}...b_{i}且\sum_{j=1}^{i} b_{i}=s-t且 b_{j} \ge 1 且b_{i} \ge k-t}\prod_{i=1}^{k}a^{'}_{b_{i}}$
而第$t$项最后得到的贡献就是$d_t=\sum_{k \le s \lt 2*k}g(s,t)$
听起来很麻烦?
先假设不考虑$b_{i}项$
不放把$t$加上$b_{i}$
因此有:
$c^{'}_{t}=c_{s}*\sum_{i=1}^{k}\sum_{b_{1},b_{2}...b_{i}且\sum_{j=1}^{i} b_{i}=s-t且 b_{j} \ge 1}\prod_{i=1}^{k}a^{'}_{b_{i}}$
其中$k \le t \lt 2*k$
这就简单多了
假设$A$是$a^{'}$的生成函数
设$A[k]$表示$A$第$k$次项系数
那么不难发现上式就等于
$c^{'}_{t}=c_{s}*(\sum_{i=0}^{k}A^i)[s-t]$
=$c_{s}*((\sum_{i=0}^{+\infty}A^i)$ $mod$ $x^{k})[s-t]$
=$c_{s}*(1/(1-A))$ $mod$ $x^{k})[s-t]$
这个显然是一个卷积,于是一个多项式求逆一个多项式乘法就可以求到了
现在求到了$c^{'}$,考虑将他的各项减去$b_{i} (b_{i} \ge k-t)$来得到最后$d$
因此:
$d_t=\sum_{k \le s \lt 2*k}c^{'}_{s}*a^{'}_{s-t}$
然后这也显然是个卷积,多项式乘法即可
至此问题就解决了,最后$c^{'}_{i}=c_{i}+d_{i}(0 \le i \lt k)$就是展开完之后的$f(n)$的表达式的前$k$项
只用到了多项式乘法和多项式求逆,这两个都是$O(lg(k)*k)$的
然后算上倍增的时间,最后的时间复杂度$O(lg(n)*lg(k)*k))$
最后,实现的时候,多项式求逆中会爆精度,所以只能用NTT做(或者用分块乘法...)...
----------------------------------------------------------------------------------
这里有一个很挫的实现:(只实现了最后一步,也就是把$f(n)$含有$2*k$项的表达式展开成只含有$k$项的表达式,这是最难也是最关键的一步)
(输入输出格式是:第一行输入$k$,然后输入$a^{'}$,然后输入有$2*n$项的$c$,输出展开之后只有$n$项的$c^{'}$)
#include <stdio.h> #include <string.h> #include <math.h> #include <time.h> typedef long long Long; typedef double Ld; const int N = 800000*2; const Ld PI = acos(-1); const int P = 998244353; const int W = 3; int pow(int a,int b) { a %= P; if(a < 0) a += P; int r = 1.; while(b) { if(b & 1) r = (1LL * r * a) % (1LL * P); a = (1LL * a * a) % (1LL * P); b >>= 1; } return r; } namespace NTT { int wn[N + 10]; int rev[N + 10]; void init(int n) { int log2n = 0; int nn = (n >> 1); while(nn) { log2n ++; nn >>= 1; } int num = 0; for(int i = 1;i <= n;i <<= 1) { wn[num++] = pow(W,(P-1)/i); } int x = 0; int y = 0; for(int i = 0;i < n;i++) { x = i; y = 0; for(int j = 1;j <= log2n;j++) { y <<= 1; y |= (x & 1); x >>= 1; } rev[i] = y; } } int buf[N + 10]; void NTT(int * a,int n,int s) { for(int i = 0;i < n;i++) buf[i] = a[rev[i]]; for(int i = 0;i < n;i++) a[i] = buf[i]; int t = 2; int div2 = 0; int l = 0; int num = 0; while(t <= n) { div2 = (t >> 1); l = n / t; int ww = wn[++num]; if(s) ww = pow(ww,P-2); for(int i = 0;i < n;i += t) { int w = 1; for(int j = 0;j < div2;j++) { int x = a[i+j]; int y = (1LL * a[i+j+div2] * w) % (1LL * P); a[i+j] = (x + y) % P; a[i+j+div2] = (x + P - y) % P; w = (1LL * w * ww) % (1LL * P); } } t <<= 1; } if(s) { int di = pow(n,P-2); for(int i = 0;i < n;i++) a[i] = (1LL * a[i] * di) % (1LL * P); } } void mul(int * a,int * b,int * c,int n) { init(2 * n); NTT(a,2 * n,0); NTT(b,2 * n,0); for(int i = 0;i < 2 * n;i++) c[i] = (1LL * a[i] * b[i]) % (1LL * P); NTT(c,2 * n,1); } int tmp[N + 10]; void get_inv(int * a,int * b,int t)//b = a^-1 mod x^t { if(t == 1) { b[0] = pow(a[0],P-2); return ; } get_inv(a,b,(t + 1) >> 1); int k = 0; for(k = 1;k <= (t << 1) + 3;k <<= 1); for(int i = 0;i < t;i++) tmp[i] = a[i]; for(int i = t;i < k;i++) tmp[i] = 0; init(k); NTT(tmp,k,0); NTT(b,k,0); for(int i = 0;i < k;i++) { int val = (1LL * b[i] * b[i]) % (1LL * P); tmp[i] = (1LL * tmp[i] * val) % (1LL * P); } NTT(tmp,k,1); NTT(b,k,1); for(int i = 0;i < k;i++) b[i] = (2LL * b[i] + P - tmp[i]) % P; for(int i = t;i < k;i++) b[i] = 0; } }; int pa[N + 10]; int a[N + 10]; int d[N + 10];//d[i] = c[i+n] int c[N + 10]; int b[N + 10]; int p[N + 10]; int main() { FILE * fin = fopen("test.in","r"); FILE * fout = fopen("test.out","w"); int n = 0; fscanf(fin,"%d",&n); for(int i = 1;i <= n;i++) { fscanf(fin,"%d",a + i); a[i] = ((a[i] % P) + P) % P; } for(int i = 0;i < 2 * n;i++) { fscanf(fin,"%d",c + i); c[i] = ((c[i] % P) + P) % P; if(i >= n) d[i] = c[i]; } { for(int i = 1;i <= n;i++) pa[i] = -a[i]; pa[0] ++; NTT::get_inv(pa,b,n); } { for(int i = 0;i <= (n>>1);i++) { if(i < n-i) b[i] ^= b[n-i] ^= b[i] ^= b[n-i]; } int e = 3 * n; int q = 1; while(q <= e) q <<= 1; NTT::mul(d,b,p,q); for(int i = 3 * n - 1;i >= 2 * n;i--) { p[i-n] = p[i]; p[i] = 0; } } { memset(b,0,sizeof(b)); for(int i = 0;i <= n;i++) { if(i < 2*n-i) p[i] ^= p[2*n-i] ^= p[i] ^= p[2*n-i]; } int e = 6 * n; int q = 1; while(q <= e) q <<= 1; NTT::mul(a,p,b,q); for(int i = 0;i < 2 * n;i++) c[i] += b[2*n-i]; } for(int i = 0;i < n;i++) { if(i == n-1) fprintf(fout,"%d\n",((c[i] % P) + P) % P); else fprintf(fout,"%d ",((c[i] % P) + P) % P); } fclose(fin); fclose(fout); return 0; }
----------------------------------------------------------------------------------
2016年4月26日 21:40
大爷好强%%%