「CodeNote」 母函数

Posted by Dawn-K's Blog on June 26, 2020

母函数

参考资料

[toc]

概念

母函数是用于解决组合问题的. 常见的母函数有两种: 普通型和指数型. 前者适用于多重集组合的问题, 后者适合涉及到排列的问题.

其原理是通过加法和乘法原理来模拟物品组合的结果, 然后通过系数和次数来表现出来. 系数主要是可行方案的个数, 次数主要是指组合出来的总权重.

公式

通用公式

\[F(x) = \prod_{i=1}^{i=n}\sum_{j=start[i]}^{j=end[i]}x^{(v[i]*j)}\]

式中, start[i] 表示第i个物品能够选择的最少的数量(通常是0) end[i] 表示第i个物品能够选择的最多的数量(有可能是无穷), 无穷情况下就看题目要求, 进行剪枝, 少进行不必要的计算.(比如每个物品没有上限, 求凑出6个物品的某种方案, 那么就可以认定end[i]=6) v[i] 表示第i个物品的权重(比如重2克的砝码, v[i]就是2), 也就是选择一个此物品能够带来的贡献

指数公式

六个数字, 三个是1, 两个是6, 一个是8, 能组成多少四位数? 所以我们给出指数形母函数: \(F(x)= \prod_{i=1}^{n} \sum_{m\in{M_i}} \frac{x^m}{m!}\)

乍一看比较懵, 其实含义只是普通的除以一个排列数以去重而已. 应用到上例中, $Ge(x)=(1+x+\frac{x^2}{2!}+\frac{x^3}{3!})(1+x+\frac{x^2}{2!})(1+x) =1+3x+8 \frac{x^2}{2!}+19 \frac{x^3}{3!}+38 \frac{x^4}{4!} +60 \frac{x^5}{5!}+60 \frac{x^6}{6!} $ 所以有38种.

这个的原理其实很好理解, 就是相比于通用公式, 我们在最后的结果上除以对应项的阶乘, 等于放大了系数, 也就是说如果是四种不同的元素组成的话, 即使每个元素都只有一个, 那么最后的方案数也是4!, 但是我们还要考虑到, 排列中如果出现m个相同的, 那么最后的方案数也要除以m!, 这个体现在括号内部除以m!.

总结: 对于括号每一项, 要考虑如果重复出现, 则要除以阶乘以去重. 对于结果, 也要除以阶乘, 目的是放大系数, 使得实现排序的效果.

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// 最后的结果存在result数组中,如果end[i]为无穷,则去掉二层循环中对其的判断
// 最终要达到的数值,也就是次数的上限,p可以稍微大个1 2
// 记得创建 start[] 和 numEnd[]并初始化
int P ;
// 实际上MAX = P+1即可
int MAX = P + 1;
// 项的个数
int factorNum ;
// 这两个数组其实开在这里不太好
int result[MAX], tmpResult[MAX];
// 虽然在这里一样,但是实际上可以把上面两个数组改成全局的,这里就可以起到优化效果了.
int sz = min((P + 1) * sizeof(int), sizeof(tmpResult));
memset(result, 0, sz);
result[0] = 1;
for (int i = 0; i < factorNum; ++i) {  // 逐个括号处理
    memset(tmpResult, 0, sz);
    // 遍历括号内部的序列
    for (int j = start[i]; j <= numEnd[i] && j * v[i] <= P; ++j) {
        // 遍历reslut数组,和已有的结果相乘,得到新的结果.
        for (int k = 0; k + j * v[i] <= P; ++k) {
            tmpResult[k + v[i] * j] += result[k];
        }
    }
    memcpy(result, tmpResult, sz);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 优化版,但是细节增加了
// 初始化a,因为有last,所以这里无需初始化其他位
a[0]=1;
int last=0;
for (int i=0;i<K;i++)
{
	int last2=min(last+n[i]*v[i],P);//计算下一次的last
	memset(b,0,sizeof(int)*(last2+1));//只清空b[0..last2]
	for (int j=n1[i];j<=n2[i]&&j*v[i]<=last2;j++)//这里是last2
		for (int k=0;k<=last&&k+j*v[i]<=last2;k++)//这里一个是last,一个是last2
			b[k+j*v[i]]+=a[k];
	memcpy(a,b,sizeof(int)*(last2+1));//b赋值给a,只赋值0..last2
	last=last2;//更新last
}