合约数(DFS序+可持久化线段树)

描述

传送门:埃森哲杯第十六届上海大学程序设计联赛春季赛暨上海高校金马五校赛 - Problem D

给定一棵$n$个节点的树,并且根节点的编号为$p$,第$i$个节点有属性值$val_i$, 定义$F(i)$: 在以$i$为根的子树中,属性值是$val_i$的合约数的节点个数。$y$ 是 $x$ 的合约数是指 $y$ 是合数且 $y$ 是 $x$ 的约数。小埃想知道$\sum_\limits{i=1}^n i · F(i)$对$1000000007$取模后的结果。

输入

输入测试组数$T$。
每组数据,输入$n+1$行整数,第一行为$n$和$p$,$1≤n≤20000$,$1≤p≤n$。
接下来$n−1$行,每行两个整数$u$和$v$,表示$u$和$v$之间有一条边。
第$n+1$行输入$n$个整数$val_1, val_2,…, val_n$,其中$1≤val_i≤10000$,$1≤i≤n$。

输出

对于每组数据,输出一行,包含1个整数表示$\sum_\limits{i=1}^n i · F(i)$对$1000000007$取模后的结果

样例输入

1
2
3
4
5
6
7
8
9
10
11
2
5 4
5 3
2 5
4 2
1 3
10 4 3 10 5
3 3
1 3
2 1
1 10 1

样例输出

1
2
11
2

思路

  • 看到子树,一般会想到dfs序。也就是说,$F(i)$就是在一个区间内,$val_i$的合约数的出现次数之和。
  • 一个数的合约数可以通过筛法预处理得到。
  • 只需要根据dfs序用一个可持久化数据结构维护dfs序列中前$x$个数每个数的出现次数即可。
  • 这样对于每个点,我们只需要枚举它的合约数,将两个查询的结果减一下就得到了$F(i)$。
  • 用可持久化线段树就好了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>
using namespace std;
#define clr(a, x) memset(a, x, sizeof(a))
#define mp(x, y) make_pair(x, y)
#define pb(x) push_back(x)
#define X first
#define Y second
#define fastin \
ios_base::sync_with_stdio(0); \
cin.tie(0);
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
typedef vector<int> VI;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-6;

const int N = 1 << 15;
vector<int> G[N];
int v[N], rt[N];
vector<int> y[N];
bool vis[N];
int dfn, tot;

inline void init()
{
for (int i = 2; i < N; i++)
{
if (!vis[i])
for (int j = i + i; j < N; j += i)
vis[j] = 1;
else
for (int j = i; j < N; j += i)
y[j].push_back(i);
}
}
struct Node
{
int l, r, cnt;
} t[N << 6];

void update(int& x, int p, int l, int r)
{
t[++tot] = t[x], x = tot;
if (l == r)
{
t[x].cnt++;
return;
}
int mid = l + r >> 1;
if (p <= mid)
update(t[x].l, p, l, mid);
else
update(t[x].r, p, mid + 1, r);
}

int query(int x, int p, int l, int r)
{
if (!x) return 0;
if (l == r) return t[x].cnt;
int mid = l + r >> 1;
if (p <= mid)
return query(t[x].l, p, l, mid);
else
return query(t[x].r, p, mid + 1, r);
}

int L[N], R[N];
void dfs(int u, int fa)
{
L[u] = ++dfn;
rt[dfn] = rt[dfn - 1];
update(rt[dfn], v[u], 1, 10000);
for (auto& v : G[u])
if (v != fa) dfs(v, u);
R[u] = dfn;
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
#endif
init();
int T;
scanf("%d", &T);
while (T--)
{
int n, p;
scanf("%d%d", &n, &p);
for (int i = 1; i <= n; i++) vector<int>().swap(G[i]);
for (int i = 1; i < n; i++)
{
static int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v), G[v].push_back(u);
}
for (int i = 1; i <= n; i++) scanf("%d", &v[i]);
dfn = tot = 0;
dfs(p, -1);
ll ans = 0;
for (int i = 1; i <= n; i++)
{
ll tmp = 0;
for (auto& c : y[v[i]])
tmp += query(rt[R[i]], c, 1, 10000) - query(rt[L[i] - 1], c, 1, 10000);
(ans += tmp * i) %= mod;
}
printf("%lld\n", ans);
}
return 0;
}
捐助作者
0%