概述
多项式求逆是多项式除法的基础。通过快速傅里叶变换和倍增,我们可以在$O(n \log n)$的复杂度内完成。
概念
定义$\deg A$表示多项式$A$的最高次数。
对于两个多项式$A$和$B$,有唯一的$A=QB+R$。其中$\deg R<\deg B$。可以写成$A\equiv R(\mod B)$。
那么使$AA’\equiv 1(\mod x^n)$的$A’$为多项式$A$在$\mod x^n$意义下的逆元,其中$\deg A’ \leqslant \deg A$。可以记做$A^{-1}$。
求法
比较显然的是,当$n=1$时,$A\equiv c(\mod x^n)$,那么$A^{-1}=c^{-1}$。
当$n>1$时,如果我们求出了$AA’\equiv 1(\mod x^{\lceil \frac{n}{2}\rceil})$,而因为$AA^{-1}\equiv1(\mod x^{\lceil \frac{n}{2}\rceil})$,那么一定有$AA^{-1}\equiv 1( \mod x^{\lceil \frac{n}{2}\rceil})$。
所以$A^{-1}-A’ \equiv 0(\mod x^{\lceil \frac{n}{2}\rceil})$。所以$(A^{-1})^2-2A^{-1}A’+(A’)^2\equiv 0 (\mod x^n)$。
乘上$A$,得到$A^{-1}-2A’+A(A’)^2\equiv 0(\mod x^n)$。所以$A^{-1}\equiv 2A’-A(A’)^2(\mod x^n)$。
时间复杂度为$O(n \log n)$。
代码
实现的时候要注意细节。
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const LL Mod = 998244353;
const LL MaxN = 400010;
LL N, n, A[ MaxN ], B[ MaxN ];
LL Index[ MaxN ], Omega[ MaxN ], InvOmega[ MaxN ];
LL a[ MaxN ], b[ MaxN ];
LL FastPow( LL x, LL y ) {
if( !y ) return 1LL;
LL t = FastPow( x, y >> 1 );
t = t * t % Mod;
if( y & 1 ) t = t * x % Mod;
return t;
}
void Init() {
scanf( "%lld", &N );
for( LL i = 0; i < N; ++i )
scanf( "%lld", &A[ i ] );
for( n = 1; n < N << 1; n <<= 1 );
Omega[ 0 ] = 1; Omega[ 1 ] = FastPow( 3, ( Mod - 1 ) / n );
for( LL i = 2; i < n; ++i )
Omega[ i ] = Omega[ i - 1 ] * Omega[ 1 ] % Mod;
InvOmega[ n - 1 ] = FastPow( Omega[ n - 1 ], Mod - 2 );
for( LL i = n - 2; i >= 0; --i )
InvOmega[ i ] = InvOmega[ i + 1 ] * Omega[ 1 ] % Mod;
return;
}
void Print() {
for( LL i = 0; i < N; ++i )
printf( "%lld ", B[ i ] );
printf( "\n" );
return;
}
void NTT( LL L, LL *A, LL *Omega ) {
for( LL i = 0; i < L; ++i )
Index[ i ] = ( Index[ i >> 1 ] >> 1) | ( ( i & 1 ) * ( L >> 1 ) );
for( LL i = 0; i < L; ++i )
if( i < Index[ i ] )
swap( A[ i ], A[ Index [ i ] ] );
for( LL HalfLen = 1; HalfLen < L; HalfLen <<= 1 )
for( LL i = 0; i < L; i += HalfLen << 1 )
for( LL j = 0; j < HalfLen; ++j ) {
LL T = Omega[ n / 2 / HalfLen * j ] * A[ i + j + HalfLen ] % Mod;
LL t = A[ i + j ];
A[ i + j ] = ( t + T ) % Mod;
A[ i + j + HalfLen ] = ( t - T + Mod ) % Mod;
}
return;
}
void PolynomInverse( LL Len ) {
if( Len == 1 ) {
B[ 0 ] = FastPow( A[ 0 ], Mod - 2 );
return;
}
PolynomInverse( ( Len + 1 ) >> 1 );
LL len = 1;
for( ; len < Len << 1; len <<= 1 );
for( LL i = 0; i < Len; ++i ) a[ i ] = A[ i ];
for( LL i = 0; i < Len; ++i ) b[ i ] = B[ i ];
NTT( len, a, Omega ); NTT( len, b, Omega );
for( LL i = 0; i < len; ++i ) {
b[ i ] = b[ i ] * ( 2 - a[ i ] * b[ i ] % Mod ) % Mod;
b[ i ] = ( b[ i ] + Mod ) % Mod;
}
NTT( len, b, InvOmega );
LL Inv = FastPow( len, Mod - 2 );
for( LL i = 0; i < len; ++i )
b[ i ] = ( b[ i ] * Inv ) % Mod;
for( LL i = Len; i < len; ++i ) b[ i ] = 0;
memcpy( B, b, sizeof( b ) );
memset( a, 0, sizeof( a ) );
memset( b, 0, sizeof( b ) );
return;
}
int main() {
Init();
PolynomInverse( N );
Print();
return 0;
}