[CSP-S模拟测试]:matrix(DP)

题目描述

求出满足以下条件的$n imes m$的$01$矩阵个数:
(1)第$i$行第$1~l_i$列恰好有$1$个$1$。
(2)第$i$行第$r_i~m$列恰好有$1$个$1$。
(3)每列至多有$1$个$1$。


输入格式

第一行两个整数$n,m$。接下来$n$行每行$2$个整数$l_i,r_i$。


输出格式

一行一个整数表示答案。对998244353取模。


样例

样例输入

2 6
2 4
5 6

样例输出

12


数据范围与提示

对于$20\%$的数据,$n,mleqslant 12$。
对于$40\%$的数据,$n,mleqslant 50$。
对于$70\%$的数据,$n,mleqslant 300$。
对于$100\%$的数据,$n,mleqslant 3000$,$1leqslant l_i<r_ileqslant m$。


题解

看到这道题,首先应该想到组合数,然后……

我一开始想的是容斥,正好能模过样例,但是忽然发现是一个多步容斥,无语……

$20\%$算法:

使劲搜就好了,别搜错了,记得别用clock(),递归函数里clock()返回值玄学(考场上被这东西干没了40分……)。

时间复杂度:$Theta(m^n)~Theta(m^{frac{n^2}{4}})$。

期望得分:$20$分。

$100\%$算法:

考虑$DP$,定义有点意思,定义$dp[i][j]$表示当前到了第$i$列,已经有$j$行在右侧区间放$1$的方案数。

下面来解释一下,注意右侧区间不是第$i$列的右侧,而是$r_i$,理解这东西我用了半个小时……

下面在来看一下如何转移:

  首先,要预处理一些东西,用一个$fl$数组表示在第$i$列以前结束的左区间的个数,$fr$数组表示在地$i$列以前开始的右区间的个数,注意这里的左区间和右区间也指的是$l_i$和$r_i$。

  然后就来明确两个值:$i-j-l[i-1]$就是第$i$列左侧$1$的个数,$l[i]-l[i-1]$就是在当前列结束的左区间的个数。

  那么,将有$4$种转移方式:

    $alpha.$如果$(i-j-l[i-1])<l[i]-l[i-1]$那么当前的$i$列$j$行以及往后更大的$j$都不可能贡献方案数,所以直接$break$即可。

    $eta.$$dp[i][j]=dp[i][j] imes A_{i-j-l[i-1]}^{l[i]-l[i-1]}$,需要用到这种转移是因为我们在下面两种转移的时候不方便执行这种操作。

    $gamma.$$dp[i-1][j]+=dp[i][j]$,这种就是不在这列放$1$,进行直接转移。

    $delta.$$dp[i][j]=dp[i-1][j-1] imes (r[i]-j+1)$这种就是放$1$,方案数为$r[i]-j+1$。

排列数$A$可以选择直接用$Theta(n imes m)$进行打表即可。

时间复杂度:$Theta(n imes m)$。

期望得分:$100$分。


代码时刻

$20\%$算法:

#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
int n,m;
struct rec{int l,r;}q[3001];
bool vis[3001];
int ans;
int lmax,rmin=10000;
int qpow(int x,int y)
{
	int res=1;
	while(y)
	{
		if(y&1)res=res*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return res;
}
void dfs(int x)
{
	if(x>n)
	{
		ans=(ans+1)%mod;
		return;
	}
	for(int i=1;i<=q[x].l;i++)
		for(int j=q[x].r;j<=m;j++)
			if(!vis[i]&&!vis[j])
			{
				vis[i]=vis[j]=1;
				dfs(x+1);
				vis[i]=vis[j]=0;
			}
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d%d",&q[i].l,&q[i].r);
		lmax=max(lmax,q[i].l);
		rmin=min(rmin,q[i].r);
	}
	dfs(1);
	printf("%d",ans);
	return 0;
}

$100\%$算法:

#include<bits/stdc++.h>
using namespace std;
int n,m;
int l[5000],r[5000],A[5000][5000],dp[5000][5000];
void pre_work()
{
	A[0][0]=1;
	for(int i=1;i<=m;i++)
	{
		A[i][0]=1;
		for(int j=1;j<=i;j++)
			A[i][j]=(1LL*A[i][j-1]*(i-j+1))%998244353;
	}
}
int main()
{
	scanf("%d%d",&n,&m);
	pre_work();
	for(int i=1;i<=n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		l[x]++;r[y]++;
	}
	for(int i=1;i<=m;i++)
	{
		l[i]+=l[i-1];
		r[i]+=r[i-1];
	}
	dp[1][0]=1;
	for(int i=1;i<=m;i++)
	{
		for(int j=0;j<=r[i];j++)
		{
			if(i-j<l[i])break;
			dp[i][j]=1LL*dp[i][j]*A[i-j-l[i-1]][l[i]-l[i-1]]%998244353;
			dp[i+1][j]=(dp[i+1][j]+dp[i][j])%998244353;
			dp[i+1][j+1]=(dp[i+1][j+1]+1LL*dp[i][j]*(r[i+1]-j)%998244353)%998244353;
		}
	}
	cout<<dp[m][n];
	return 0;
}

rp++