kurarrr's memooo

主に競プロ備忘録

ARC 067 E - Grouping

問題概要

N人をいくつかのグループに分ける. ただし以下の制約を満たす.
グループに含まれる人数iは a <= i <= bである.
i 人が含まれるグループの数を f_i とした時, c <= f_i <= d.
分け方の総数を求めよ.
1 <= N <= 103

解法

sum(a<=e<=b)f_e*e = N を満たす f_e が決まれば (e人のグループがf_e個ずつあるということ),あとはコンビネーションで求まる.
雰囲気がDPっぽいのでDPでやる.
dp[i][j] := (f
(a+i+1)までを決めて, 合計 j 人の時の組み合わせ) とすると,
dp[i][j] += dp[i-1][j] (f=0とした時)
dp[i][j] += sum_(c<=f<=d) (N-j+ef)! / (f! * (N-j)! * (e!)f )
(N-j+ef人の中からe人のグループをf個取ってきて残りをN-j人にした(j人選んだ状態にした))
ここで,f個のグループは区別しないことに注意.
割り算はフェルマーの小定理使って1e9+7に関する逆元求めるやつ使う.
一見O(N3)ぽいが, e*f <= j でなければならないのでO(N2 * logN)ですむ.

メモ

配るDPの方が楽そうだったけどなぜかエラー出たしあんまり好きじゃないんだよなあ.

逆元を求めるのが重いので多用しない方が良い.
o ... imod(a) * imod(b)
x ... imod(a*b)

コード

#include "bits/stdc++.h"

#define ALL(g) (g).begin(),(g).end()
#define REP(i, x, n) for(int i = x; i < n; i++)
#define rep(i,n) REP(i,0,n)
#define RREP(i, x, n) for(int i = x; i >= n; i--)
#define rrep(i, n) RREP(i,n,0)
#define pb push_back

using namespace std;

using ll = long long;
using P = pair<int,int>;
using Pl = pair<ll,int>;

const int mod=1e9+7,INF=1<<30;
const double EPS=1e-12,PI=3.1415926535897932384626;
const ll lmod = 1e9+7,LINF=1LL<<60;
const int MAX_N = 1003;

vector<ll> fact;

void init(ll N){
  fact.resize(N+1);
  fact[0]=1;
  rep(i,N)  fact[i+1]=((i+1)*fact[i])%lmod;
}

ll ppow(ll a,ll b){
  // aのb乗を求める
  ll res = 1;
  while(b){
    if(b%2) res=(res*a)%lmod;
    a=(a*a)%lmod;
    b>>=1;
  }
  return res;
}

ll imod(ll n){
  // nのmodの逆元を求める
  ll P=lmod;
  return ppow(n,P-2);
}

ll comb_mod(ll n,ll k){
  //nまで埋めた階乗テーブルを渡す
  return (fact[n] * imod(fact[k]) % lmod) * imod(fact[n-k]) % lmod ;
  //nCk % mod を返す
}

int N,a,b,c,d;
ll dp[MAX_N][MAX_N];

int main(){
  cin >> N >> a >> b >> c >> d;
  init(N+2);
  dp[0][0] = 1LL;
  REP(i,1,b-a+2){
    rep(j,N+1){
      dp[i][j] += dp[i-1][j];
      int e = a+i-1;
      for(int f=c; f<=d && e*f<=j;f++){
        dp[i][j] += dp[i-1][j-e*f]*fact[N-j+e*f]%lmod*imod(fact[f]*fact[N-j]%lmod*ppow(fact[e],f)%lmod);
        dp[i][j] %= lmod;
      }
    }
  }
  cout << dp[b-a+1][N] << endl;
  return 0;
}