跳转至

线段树 1

操作:

  • pushup: 根据子节点 更新 父节点
    • eg. sum = l.sum + r.sum
  • pushdown: 父节点的修改信息,下传到子节点
    • alias: 懒标记 / 延迟标记

线段树是算法竞赛中常用的用来维护 区间信息 的数据结构

线段树可以在 \(O(logN)\) 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作

(1) 一般用来:

  • 单点修改
  • 区间修改: 见下节课 "线段树-2", 这里不提
  • 区间查询(即区间求和,求区间 max ,求区间 min ,区间 gcd ...)

但是,线段树所维护的信息,需要满足区间加法:

Text Only
1
区间加法:如果一个区间 [l,r](线段树中一个点表示一个区间)满足区间加法的意思是一个区间 [l,r] 的线段树维护的信息(即区间最大值,区间最小值,区间和,区间 gcd 等),可以由两个区间 [l,mid] 和 [mid+1,r] 合并而来

(2) 基础操作:

  1. pushup(u)
  2. pushdown(u)
  3. build(): 将一段区间 初始化 成线段树
  4. modify(): 修改
    1. 单点: 很简单
    2. 一段区间: 很难。要结合 pushdown(), 懒标记
  5. query()

(3) 存储:

  • 空间开设: 如果有n个点, 开的空间是 4n
  • 存储方式: 采用 heap 存储
    • 一棵线段树的 root 的编号是 1
    • 设一个不为根的节点编号 x ,则这个点的父节点: x/2(向上取整), 子节点分别是 2x2x+1
    • 父节点: x>>1
    • 左儿子: x<<1
    • 右儿子: x<<1 | 1

表示方式:

C++
1
2
3
4
struct node {
    int l, r; // node表示的tree区间左右端点
    int v; // ... 区间内需要表示的"属性", 因地制宜
}tr[4*N];

(4) 基本概念

(5) 模板:

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// pushup
// 由子节点的信息, 来"向上更新"父节点的信息
// 需要因地制宜

// (1) 1275
void pushup(int u)  // 由子节点的信息,来计算父节点的信息
{
    tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}

// (2) 245
void pushup(Node &u, Node &l, Node &r)
{
    // u父亲, l左儿子, r右儿子
    u.sum = l.sum + r.sum;
    u.lmax = max(l.lmax, l.sum + r.lmax);
    u.rmax = max(r.rmax, r.sum + l.rmax);
    u.tmax = max(max(l.tmax, r.tmax), l.rmax + r.lmax);
}

void pushup(int u) // 当前节点u, 要从它的两个儿子处更新
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
// build
// (1) 主函数调用初始化:
build(1, 1, n); // 1: 从"顶层"树开始, 表示总体区间; [1, n]: 从1到n

// (2) 构建函数的定义:
// 当前正在下标为u的点, 这个点表示的区间是 [l,r]
void build(int u, int l, int r)
{
    // u: 节点编号 l: 区间左端点 r: 区间右端点
    tr[u] = {l, r}; // 记得存储当前点表示的区间,否则你会调上一整天!!!

    if (l == r) return;

    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
// modify
// 以二分查找的形式找到要修改的点, 然后把"向上"的链都修改
void modify(int u, int x, int v) // 将 x 处的值修改成 v
{
    if (tr[u].l == x && tr[u].r == x) tr[u].v = v;
    else
    {
        // 先分治得到结果
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        // 再将结果向上传递给父节点, 逐层更新
        pushup(u);
    }
}
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
// query
// 对于nodeU, 查询区间[l, r]
int query(int u, int l, int r)
{
    // 树中节点, 已经被完全包含在[l, r]中了, 直接return
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;

    int mid = tr[u].l + tr[u].r >> 1;
    int v = 0;
    if (l <= mid) v = query(u << 1, l, r);
    if (r > mid) v = max(v, query(u << 1 | 1, l, r));

    return v;
}

注意 modify()query() 是否需要 "将结果向上传递给父节点, 逐层更新"

1275 最大数

分析, 这个题目的两个操作等价于:

  1. 在某一个位置, 修改一个数 (单点修改)
  2. 求某个区间内的最大值

构建, tree node:

C++
1
2
3
4
5
// 求 "某个区间" 内的 "某个属性"
struct node {
    int l, r;
    int v; // 维护的是 区间[l, r] 的 "属性" - 最大值
}tr[N * 4]; // N 个点, 空间开 4*N

代码:

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 200010;

int m, p;
struct Node
{
    int l, r;
    int v;  // 区间[l, r]中的最大值
}tr[N * 4];

/*
review 堆heap的数组构建表示法

- 父节点: x>>1.
- 左son: 2x.   alias x<<1
- 右son: 2x+1. alias x<<1 | 1
*/

void pushup(int u)  // 由子节点的信息,来计算父节点的信息
{
    tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}

void build(int u, int l, int r)
{
    // u: 节点数量 l: 区间左端点 r: 区间右端点
    tr[u] = {l, r};
    if (l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

int query(int u, int l, int r)
{
    // 树中节点, 已经被完全包含在[l, r]中了, 直接return
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;

    int mid = tr[u].l + tr[u].r >> 1;
    int v = 0;
    if (l <= mid) v = query(u << 1, l, r);
    if (r > mid) v = max(v, query(u << 1 | 1, l, r));

    return v;
}

void modify(int u, int x, int v) // 将 x 处的值修改成 v
{
    if (tr[u].l == x && tr[u].r == x) tr[u].v = v;
    else
    {
        // 先分治得到结果
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        // 再将结果向上传递给父节点, 逐层更新
        pushup(u);
    }
}


int main()
{
    int n = 0, last = 0;
    scanf("%d%d", &m, &p);
    build(1, 1, m);

    int x;
    char op[2];
    while (m -- )
    {
        scanf("%s%d", op, &x);
        if (*op == 'Q')
        {
            last = query(1, n - x + 1, n);
            printf("%d\n", last);
        }
        else
        {
            modify(1, n + 1, ((LL)last + x) % p);
            n ++ ;
        }
    }

    return 0;
}

245 你能回答这些问题吗

纯模板线段树

分析, 这个题目的两个操作等价于:

  1. 在某一个位置, 修改一个数 (单点修改)
  2. 求某个区间内的"连续最大子段和"

构建, tree node:

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
// 求 "某个区间" 内的 "某个属性"
struct node {
    int l, r;
    int tmux; // 最大连续子段和

    // 不够! 横跨左右子区间的最大子段和 = 
    // 左子区间最大后缀和 + 右子区间最大前缀和
    int rmax; // 最大后缀和
    int lmax; // 最大前缀和

    // 还不够! 分析见下
    int sum; // 区间和

    // 现在ok了

}tr[N * 4]; // N 个点, 空间开 4*N

考虑一下更新的方式:

(1) tmax:

C++
1
2
3
4
5
tmax = max (
    left_son.tmax,
    right_son.tmax,
    left_son.rmax + right_son.lmax
)

(2) lmax:

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
/*
(1) 没跨过左半边
|-----|
|-------------||-------------|
*/
lmax = left_son.lmax
/*
(2) 跨过左半边了
|------------------|
|-------------||-------------|
*/
lmax = left_son.sum + right_son.lmax

(3) rmax: 同理

(4) sum: sum = left_son.sum + right_son.sum

代码:

C++
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 500010;

int n, m;
int w[N];
struct Node
{
    int l, r;
    int sum, lmax, rmax, tmax;
}tr[N * 4];

void pushup(Node &u, Node &l, Node &r)
{
    // u父亲, l左儿子, r右儿子
    u.sum = l.sum + r.sum;
    u.lmax = max(l.lmax, l.sum + r.lmax);
    u.rmax = max(r.rmax, r.sum + l.rmax);
    u.tmax = max(max(l.tmax, r.tmax), l.rmax + r.lmax);
}

void pushup(int u) // 当前节点u, 要从它的两个儿子处更新
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r) // 对子树u, 区间 [l,r], 构建线段树
{
    if (l == r) tr[u] = {l, r, w[r], w[r], w[r], w[r]};
    else
    {
        tr[u] = {l, r}; // 初始化
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int x, int v) // 将x处的值改成v
{
    if (tr[u].l == x && tr[u].r == x) tr[u] = {x, x, v, v, v, v};
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid) modify(u << 1, x, v); // 左半边
        else modify(u << 1 | 1, x, v); // 右半边
        pushup(u); // 向上传递给父节点
    }
}

Node query(int u, int l, int r)
{
    // 查的node编号: u
    // 查询区间: [l, r]
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (r <= mid) return query(u << 1, l, r);
        else if (l > mid) return query(u << 1 | 1, l, r);
        else
        {
            auto left = query(u << 1, l, r);
            auto right = query(u << 1 | 1, l, r);
            Node res;
            pushup(res, left, right); // 向上传递给父节点
            return res;
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);

    // 根据这两个信息构建线段树:
    // - u=1: 顶层最长的那个区间, 即: 总区间
    // - 区间 [1, n]
    build(1, 1, n);

    int k, x, y;
    while (m -- )
    {
        scanf("%d%d%d", &k, &x, &y);
        if (k == 1)
        {
            if (x > y) swap(x, y);
            // u=1: 顶层最长的那个区间, 即: 总区间
            printf("%d\n", query(1, x, y).tmax);
        }
        else modify(1, x, y); // u=1: 顶层最长的那个区间, 即: 总区间
    }

    return 0;
}

求最大值的方法:

C++
1
2
3
4
// 用 { } 语法
maxv = max({l.tmax, r.tmax, l.rmax + r.lmax});
// 先比两个, 再跟第三个比 (比较保险)
maxv = max( max(l.tmax, r.tmax), l.rmax + r.lmax);

246 区间最大公约数

C++
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 500010;

int n, m;
LL w[N];
struct Node
{
    int l, r;
    LL sum, d;
}tr[N * 4];

LL gcd(LL a, LL b)
{
    return b ? gcd(b, a % b) : a;
}

void pushup(Node &u, Node &l, Node &r)
{
    u.sum = l.sum + r.sum;
    u.d = gcd(l.d, r.d);
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    if (l == r)
    {
        LL b = w[r] - w[r - 1];
        tr[u] = {l, r, b, b};
    }
    else
    {
        tr[u].l = l, tr[u].r = r;
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int x, LL v)
{
    if (tr[u].l == x && tr[u].r == x)
    {
        LL b = tr[u].sum + v;
        tr[u] = {x, x, b, b};
    }
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        pushup(u); // 向上传递给父节点
    }
}

Node query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (r <= mid) return query(u << 1, l, r);
        else if (l > mid) return query(u << 1 | 1, l, r);
        else
        {
            auto left = query(u << 1, l, r);
            auto right = query(u << 1 | 1, l, r);
            Node res;
            pushup(res, left, right); // 向上传递给父节点
            return res;
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ ) scanf("%lld", &w[i]);
    build(1, 1, n);

    int l, r;
    LL d;
    char op[2];
    while (m -- )
    {
        scanf("%s%d%d", op, &l, &r);
        if (*op == 'Q')
        {
            auto left = query(1, 1, l);
            Node right({0, 0, 0, 0});
            if (l + 1 <= r) right = query(1, l + 1, r);
            printf("%lld\n", abs(gcd(left.sum, right.d)));
        }
        else
        {
            scanf("%lld", &d);
            modify(1, l, d);
            if (r + 1 <= n) modify(1, r + 1, -d);
        }
    }

    return 0;
}