跳转至

Chapter 6 贪心

知识

alt text

alt text

alt text

alt text

alt text

模板

区间问题

常见做法: 将区间以 "左/右端点" 排序

(1) 区间选点:

  1. 将区间按 右端点 排序
  2. 遍历区间,如果该区间中不包含最后选的那个点,则选取该区间右端点;如果包含最后选的那个点,则跳过
  3. 输出所选点的个数
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
#include <algorithm>
using namespace std;

const int maxn = 1e5 + 10;

struct segment {
    int l, r;
}segments[maxn];

bool cmp(segment a, segment b)
{
    return a.r < b.r;
}

int n;

int main()
{
    cin >> n;
    for (int i=1; i<=n; i++) {
        int a, b;
        cin >> a >> b;
        segments[i].l = a;
        segments[i].r = b;
    }

    // 右端点排序
    sort(segments + 1, segments + 1 + n, cmp);

    int ans = 0;
    int end = -2e9; // 当前记录到的"最右侧"

    for (int i=1; i<=n; i++)
    {
        if (segments[i].l > end)
        {
            ans++;
            end = segments[i].r;
        }
    }

    cout << ans <<endl;
    return 0;
}

(2) 最大不相交区间数目:

  1. 将区间按右端点排序
  2. 遍历区间,如果该区间和上一个选的区间有重合,则跳过;如果和上一个选的区间没有重合,则选取该区间
  3. 输出所选区间的个数
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
#include <algorithm>
using namespace std;

const int maxn = 1e5 + 10;

struct segment {
    int l, r;
}segments[maxn];

bool cmp(segment a, segment b)
{
    return a.r < b.r;
}

int n;

int main()
{
    cin >> n;
    for (int i=1; i<=n; i++) {
        int a, b;
        cin >> a >> b;
        segments[i].l = a;
        segments[i].r = b;
    }

    // 右端点排序
    sort(segments + 1, segments + 1 + n, cmp);

    int ans = 0;
    int end = -2e9; // 当前记录到的"最右侧"

    for (int i=1; i<=n; i++)
    {
        if (segments[i].l > end)
        {
            ans++;
            end = segments[i].r;
        }
    }

    cout << ans <<endl;
    return 0;
}

(3) 区间分组:

  1. 将区间按左端点排序
  2. 依次遍历区间,如果当前区间能放到 之前的某个集合 中,则把该区间放到该集合;如果当前不能放到 任意一个之前的集合 中,则新开一个集合,把当前区间放到新开的集合中
  3. 集合的数量就是答案

关键步骤是第二步,如何判断当前区间能否放到之前的集合中。解决方法如下:

  1. 记录每个集合中保存的区间的最右侧端点,如果当前区间的左端点不和某个集合中保存的区间的最右侧端点相交,则当前区间不和该集合相交,能放到该集合中
  2. 也就是,我们只需判断当前区间的左端点 是否和 右侧端点最小的那个集合是否相交即可
  3. 为了快速找出右侧端点最小的那个集合,可以使 用小根堆保存每个集合的右端点
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;

const int maxn = 1e5 + 10;

struct segment {
    int l, r;
}segments[maxn];

bool cmp(segment a, segment b)
{
    return a.l < b.l;
}

int n;
// 小根堆来存放每个"会议室"的结束时间
priority_queue<int, vector<int>, greater<int>> min_heap;

int main()
{
    cin >> n;
    for (int i=1; i<=n; i++) {
        int a, b;
        cin >> a >> b;
        segments[i].l = a;
        segments[i].r = b;
    }

    // 左端点排序
    sort(segments+1, segments+1+n, cmp);

    for (int i=1; i<=n; i++)
    {
        if (min_heap.size() == 0 or min_heap.top() >= segments[i].l)
        {
            // "会议室"得新开一个
            min_heap.push(segments[i].r);
        }
        else
        {
            // "会议室"可以复用
            min_heap.pop();
            min_heap.push(segments[i].r);
        }
    }

    cout << min_heap.size() <<endl;
    return 0;
}

(4) 区间覆盖:

给定 N 个区间 [ai, bi] 以及一个区间 [s, t],请你选择尽量少的区间,将指定区间完全覆盖

  1. 将所有区间按照左端点从小到大进行排序
  2. 从前往后枚举每个区间,在所有能覆盖start的区间中,选择右端点的最大区间,然后将start更新成右端点的最大值
  3. 这一步用到了贪心决策
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;

const int N = 100010; 

struct range {
    int l, r;
}R[N];

bool cmp(range a, range b) {
    return a.l < b.l;
}


int main()
{
    int st, ed;
    cin >> st >> ed;
    int n; 
    cin >> n;
    for (int i = 0; i < n; ++i)
    {
        int x, y; 
        cin >> x >> y;
        R[i] = {x, y};
    }

    // 左端点排序
    sort(R, R+n, cmp);

    int res = 0; 
    bool success = false;
    for (int i = 0; i < n; ++i)
    {
         int j = i, right = 0xc0c0c0c0;
        /*判断左端点在st之前的区间,循环找到最大右端点,如果右端点也在st之前,说明无法覆盖*/
        while (j < n && R[j].l <= st)
        {
            right = max(right, R[j].r);
            j++;
        }

        /*如果右端点也在st之前,说明无法覆盖*/
        if (right < st)
        {
            res = -1;
            break;
        }

        /*每循环一次,没有在前面跳出的话,说明找到了一个区间,res++*/
        res++;

        /*如果这个区间右端点能覆盖end,说明能覆盖*/
        if (right >= ed)
        {
            success = true;
            break;
        }

        /*把start更新成right,保证后面的区间适合之前的区间有交集,从而形成对整个序列的覆盖*/
        st = right;
        i = j - 1;
    }

    /*如果遍历了所有的数组,还是没有覆盖最后的end,说明不能成功*/
    if (!success) res = -1;
    cout << res <<endl;
    return 0;
}

(5) 区间合并:

虽然不是贪心思想,但也是一道经典的区间问题,放在这里一块整理吧

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

typedef pair<int, int> PII; // 默认是按照左端点排序

void merge(vector<PII> &segs)
{
    vector<PII> res;

    sort(segs.begin(), segs.end());

    int st = -2e9, ed = -2e9; // 初始化
    for (auto seg : segs)
        if (ed < seg.first) // 需要另立门户
        {
            if (st != -2e9) res.push_back({st, ed});
            st = seg.first, ed = seg.second;
        }
        else ed = max(ed, seg.second); // 直接尾加

    if (st != -2e9) res.push_back({st, ed});

    segs = res;
}

int main()
{
    int n;
    cin >> n;

    vector<PII> segs;
    for (int i = 0; i < n; i ++ )
    {
        int l, r;
        cin >> l >> r;
        segs.push_back({l, r});
    }

    merge(segs);

    cout << segs.size() << endl;

    return 0;
}

Huffman树

合并果子

经典哈夫曼树的模型,每次合并重量最小的两堆果子即可

使用小根堆维护所有果子,每次弹出堆顶的两堆果子,并将其合并,合并之后将两堆重量之和再次插入小根堆中

每次操作会将果子的堆数减一,一共操作 n−1 次即可将所有果子合并成1堆 (that's why while (min_heap.size() > 1))

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;

const int maxn = 1e4 + 10;

int n;
priority_queue<int, vector<int>, greater<int>> min_heap;

int main()
{
    cin >> n;
    for (int i=1; i<=n; i++)
    {
        int x;
        cin >> x;
        min_heap.push(x);
    }

    int ans = 0;
    while (min_heap.size() > 1)
    {
        int a = min_heap.top(); min_heap.pop();
        int b = min_heap.top(); min_heap.pop();
        ans += (a+b);
        min_heap.push(a+b);
    }

    cout << ans << endl;
    return 0;
}

排序不等式

排队打水

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int maxn = 1e5 + 10;

int n;
LL a[maxn];
LL s[maxn];

int main()
{
    cin >> n;
    for (int i=1; i<=n; i++) cin >> a[i];

    sort(a+1, a+1+n);

    // // 这样也可以:
    // for (int i=1; i<=n; i++) s[i] = s[i-1] + a[i];

    // LL sum = 0;
    // for (int i=1; i<=n-1; i++)
    // {
    //     // 2: 等s[1]
    //     // 3: 等s[2]
    //     // ...
    //     // n: 等s[n-1]
    //     sum += s[i];
    // }
    LL sum = 0;
    for (int i=1; i<=n; i++) sum += a[i] * (n-i);

    cout << sum << endl;
    return 0;
}

绝对值不等式

仓库选址

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 100010;

int n;
int q[N];

int main()
{
    cin >> n;

    for (int i = 1; i <= n; i ++) cin >> q[i];

    sort(q + 1, q + 1 + n);

    int res = 0;
    for (int i = 1; i <= n; i ++) res += abs(q[i] - q[(n+1) / 2]);

    cout << res << endl;

    return 0;
}

推公式

耍杂技的牛

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

const int N = 5e4 + 5;

PII a[N];

int main()
{
    int n;
    cin >> n;
    for(int i = 0; i < n; i ++ )
    {
        int x, y;
        scanf("%d %d", &x, &y);
        a[i].first = x + y;
        a[i].second = y;
    }

    sort(a, a + n);

    ll res = -1e18, sum = 0;

    for(int i = 0; i < n; i ++ )
    {
        sum -= a[i].second;
        res = max(res, sum);
        sum += a[i].first;
    }

    cout << res << endl;
    return 0;
}