# 【hdu 5877】Weak Pair (3种解法)

2016-09-13

【hdu 5877】Weak Pair (3种解法)

Time Limit: 4000/2000 MS (Java/Others); Memory Limit: 262144/262144 K (Java/Others)

Total Submission(s): 1637; Accepted Submission(s): 531

Problem Description: You are given a *rooted* tree of $N$ nodes, labeled from 1 to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An *ordered* pair of nodes $(u, v)$ is said to be *weak* if

(1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself);

(2) $a_u \times a_v \le k$.

Can you find the number of weak pairs in the tree?

Input: There are multiple cases in the data set.

The first line of input contains an integer $T$ denoting the number of test cases.

For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.

The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.

Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.

Constraints:

$1 \le N \le 10^5$

$0 \le a_i \le 10^9$

$0 \le k \le 10^{18}$

Output: For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

Sample Input

```
1
2 3
1 2
1 2
```

Sample Output

```
1
```

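【题解】

设 $v$ 为较深的那个节点。当 $a_v > 0$ 时，$a_u \times a_v \le k$ 等价于 $a_u \le \lfloor k / a_v \rfloor$；当 $a_v = 0$ 时，$v$ 的所有祖先都满足条件。因此从根开始 DFS，用一个数据结构维护"当前节点到根的路径上所有祖先的权值"，进入每个节点时统计其中 $\le \lfloor k / a_v \rfloor$ 的个数：

```
void dfs(int x)
{
    ans += 数据结构中 <= k / a[x] 的权值个数;   // a[x] == 0 时统计全部
    把 a[x] 加入数据结构;
    dfs(儿子);
    把 a[x] 从数据结构中删除;
}
```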

1.平衡树

(这里用的是SBT)：平衡树中存放当前路径上所有祖先的权值（允许重复），find 统计其中 <= k/a[x] 的个数。

2.树状数组

seq[1..n]存a[1..n]，seq[n+1..2n]存k/a[i]（a[i]=0 时存一个足够大的 INF），排序后用 lower_bound 离散化；树状数组按离散化后的下标记录祖先权值的个数，查询前缀和即可。

3.线段树

与树状数组相同的离散化方式，线段树做单点修改、区间求和。

【代码1】平衡树

```#include
#include
#include

const int MAXN = 100009;

using namespace std;

// Tree structure (1-indexed nodes) and the SBT node pool.
// Each test case allocates exactly one pool node per tree node (insert is called
// once per dfs visit) and init() resets totn, so 2 * MAXN slots are plenty.
int T, n, father[MAXN], root_tree, root;
int si_ze[MAXN * 2], l[MAXN * 2], r[MAXN * 2], totn = 0;
// Standard `long long` instead of MSVC's `__int64` (the same 64-bit type there),
// so these declarations also compile with GCC/Clang.
long long key[MAXN * 2], k, seq[MAXN], ans = 0;
// Child lists. The element type was missing (`vector  son[MAXN];` does not compile).
vector<int> son[MAXN];

// Resets all per-test-case state: parent links, child lists, the SBT pool and the answer.
void init()
{
    ans = 0;
    totn = 0;
    root = 0;
    memset(father, 0, sizeof father);
    for (int node = 100000; node >= 1; --node)
        son[node].clear();
}

// Reads one test case: n, k, the node values into seq[1..n] and the n - 1
// parent -> child edges; then picks the first node without a parent as root_tree.
void input_data()
{
    scanf("%d%I64d", &n, &k);
    for (int i = 1; i <= n; ++i)
        scanf("%I64d", seq + i);

    for (int edge = 1; edge < n; ++edge)
    {
        int parent, child;
        scanf("%d%d", &parent, &child);
        father[child] = parent;
        son[parent].push_back(child);
    }

    // Find the root: the smallest-numbered node that never appeared as a child.
    int candidate = 1;
    while (candidate <= n && father[candidate] != 0)
        ++candidate;
    if (candidate <= n)
        root_tree = candidate;
}

// Right rotation: lifts the left child of t into t's place.
// t is a reference to the parent's child link (or to `root`), so the parent
// is rewired automatically when t is overwritten at the end.
void right_rotation(int &t)
{
    const int old_root = t;
    const int new_root = l[old_root];
    l[old_root] = r[new_root];          // new_root's right subtree moves under old_root
    r[new_root] = old_root;
    si_ze[new_root] = si_ze[old_root];  // the rotated subtree keeps its total size
    si_ze[old_root] = si_ze[l[old_root]] + si_ze[r[old_root]] + 1;
    t = new_root;
}

// Left rotation: lifts the right child of t into t's place.
// t is a reference to the parent's child link (or to `root`), so the parent
// is rewired automatically when t is overwritten at the end.
void left_rotation(int &t)
{
    const int old_root = t;
    const int new_root = r[old_root];
    r[old_root] = l[new_root];          // new_root's left subtree moves under old_root
    l[new_root] = old_root;
    si_ze[new_root] = si_ze[old_root];  // the rotated subtree keeps its total size
    si_ze[old_root] = si_ze[l[old_root]] + si_ze[r[old_root]] + 1;
    t = new_root;
}

// Size Balanced Tree rebalancing of the subtree referenced by t.
// flag == true : the left side just grew; compare the left grandchildren with
//                t's right subtree.
// flag == false: mirror case for the right side.
// After any rotation the affected subtrees are re-checked recursively.
void maintain(int &t, bool flag)
{
    // Size of the subtree the heavy-side grandchildren are compared against.
    const int sibling = flag ? si_ze[r[t]] : si_ze[l[t]];

    if (flag && si_ze[l[l[t]]] > sibling)         // left-left ("/"-shaped) case
        right_rotation(t);
    else if (flag && si_ze[r[l[t]]] > sibling)    // left-right case
    {
        left_rotation(l[t]);
        right_rotation(t);
    }
    else if (!flag && si_ze[r[r[t]]] > sibling)   // right-right case
        left_rotation(t);
    else if (!flag && si_ze[l[r[t]]] > sibling)   // right-left case
    {
        right_rotation(r[t]);
        left_rotation(t);
    }
    else
        return;                                   // already balanced

    maintain(l[t], true);
    maintain(r[t], false);
    maintain(t, true);
    maintain(t, false);
}

// Inserts data2 into the SBT referenced by t. Duplicates are allowed and
// go to the right subtree; every node on the path has its size bumped.
void insert(int &t, __int64 data2)
{
    if (t != 0)
    {
        ++si_ze[t];
        const bool went_left = data2 < key[t];
        if (went_left)
            insert(l[t], data2);
        else
            insert(r[t], data2);
        maintain(t, went_left);
        return;
    }

    // Empty link: take a fresh node from the pool.
    t = ++totn;
    l[t] = 0;
    r[t] = 0;
    si_ze[t] = 1;
    key[t] = data2;
}

// Removes one occurrence of data2 from the SBT referenced by t.
// Assumes data2 is present (the caller only deletes values it inserted).
// A node with two children takes its in-order successor's key, and the
// successor is then deleted from the right subtree.
void de_lete(__int64 data2, int &t)
{
    --si_ze[t];

    if (data2 != key[t])
    {
        if (data2 < key[t])
            de_lete(data2, l[t]);
        else
            de_lete(data2, r[t]);
        return;
    }

    const int left_child = l[t];
    const int right_child = r[t];
    if (left_child == 0)        // leaf (becomes 0) or right child only
    {
        t = right_child;
        return;
    }
    if (right_child == 0)       // left child only
    {
        t = left_child;
        return;
    }

    int successor = right_child;
    while (l[successor] != 0)
        successor = l[successor];
    key[t] = key[successor];
    de_lete(key[successor], r[t]);
}

// Returns how many keys in the SBT referenced by t are <= what.
// Walks a single root-to-leaf path: whenever the current key is <= what, the
// node and its whole left subtree are counted and the walk continues right.
__int64 find(__int64 what, int &t)
{
    __int64 count = 0;
    int cur = t;
    while (cur != 0)
    {
        if (what < key[cur])
        {
            cur = l[cur];
        }
        else
        {
            count += si_ze[l[cur]] + 1;
            cur = r[cur];
        }
    }
    return count;
}

void dfs(int rt)
{
__int64 temp = k / seq[rt];
if (seq[rt] == 0)
temp = 2100000000;//这里temp本应改成1e19的，但是貌似也对了。。
ans += find(temp, root);
insert(root, seq[rt]);
int len = son[rt].size();
for (int i = 0; i <= len - 1; i++)
dfs(son[rt][i]);
de_lete(seq[rt], root);
}

// Prints the weak-pair count of the current test case on its own line.
void output_ans()
{
    fprintf(stdout, "%I64d\n", ans);
}

int main()
{
    //freopen("F:\\rush.txt", "r", stdin);
    //freopen("F:\\rush_out.txt", "w", stdout);
    scanf("%d", &T);
    // Each test case: reset state, read the tree, count the pairs, print the answer.
    while (T)
    {
        --T;
        init();
        input_data();
        dfs(root_tree);
        output_ans();
    }
    return 0;
}```

【代码2】树状数组

```#include
#include
#include

using namespace std;

const int MAXN = 101000;
// Query threshold used when a[x] == 0 (then every ancestor forms a weak pair), so it
// only has to be >= every a_i (<= 1e9). The previous literal 10000000000000000000 (1e19)
// does not fit in long long (max ~9.22e18); converting it is implementation-defined,
// and GCC/MSVC yield 1e19 - 2^64 < 0. That negative INF sorted to the front of seq,
// so zero-valued nodes counted no ancestors. LLONG_MAX is the natural sentinel.
const long long INF = 0x7fffffffffffffffLL;

int c[MAXN * 2], a[MAXN], father[MAXN], n, lena;
long long seq[MAXN * 2], ans = 0, k;
// Child lists. The element type was missing (`vector  son[MAXN];` does not compile).
vector<int> son[MAXN];

// Resets all per-test-case state: the Fenwick tree, parent links, child lists and the answer.
void init()
{
    memset(c, 0, sizeof c);
    memset(father, 0, sizeof father);
    for (int node = 100010; node >= 1; --node)
        son[node].clear();
    ans = 0;
}

// Reads n, k and the node values, then builds the sorted coordinate array:
// seq[1..n] holds a[i] and seq[n+1..2n] holds k / a[i] (INF when a[i] == 0).
// Duplicates are harmless: lower_bound always maps a value to its first occurrence.
void input_data()
{
    scanf("%d%I64d", &n, &k);
    for (int i = 1; i <= n; ++i)
    {
        scanf("%d", &a[i]);
        seq[i] = a[i];
        seq[n + i] = (a[i] == 0) ? INF : k / a[i];   // avoid dividing by zero
    }
    lena = 2 * n;
    sort(seq + 1, seq + 1 + lena);
}

// Lowest set bit of x (the Fenwick-tree step size), e.g. lowbit(12) == 4.
int lowbit(int x)
{
    const int twos_complement = ~x + 1;
    return x & twos_complement;
}

// Adds to ans the number of ancestors u of x with a[u] * a[x] <= k, then recurses.
// Fenwick tree c[] holds, per discretized position, the count of ancestor values.
void dfs(int x)
{
    long long bound = INF;          // a[x] == 0: every ancestor qualifies
    if (a[x] != 0)
        bound = k / a[x];

    // Prefix sum over positions 1..pos counts ancestor values <= bound.
    long long hits = 0;
    for (int pos = lower_bound(seq + 1, seq + 1 + 2 * n, bound) - seq; pos > 0; pos -= lowbit(pos))
        hits += c[pos];
    ans += hits;

    const int own_pos = lower_bound(seq + 1, seq + 1 + 2 * n, a[x]) - seq;
    for (int p = own_pos; p <= lena; p += lowbit(p))
        ++c[p];
    for (int i = 0; i < static_cast<int>(son[x].size()); ++i)
        dfs(son[x][i]);
    for (int p = own_pos; p <= lena; p += lowbit(p))
        --c[p];
}

// Reads the n - 1 parent -> child edges, locates the root and runs the DFS.
void get_ans()
{
    for (int edge = 1; edge < n; ++edge)
    {
        int u, v;
        scanf("%d %d", &u, &v);
        son[u].push_back(v);
        father[v] = u;
    }

    // The root is the node without a parent; scanning backwards and stopping at
    // the first hit picks the same node as a forward scan that keeps the last one.
    int root = 1;
    for (int i = n; i >= 1; --i)
    {
        if (father[i] == 0)
        {
            root = i;
            break;
        }
    }
    dfs(root);
}

// Prints the weak-pair count of the current test case on its own line.
void output_ans()
{
    fprintf(stdout, "%I64d\n", ans);
}

int main()
{
    //freopen("F:\\rush.txt", "r", stdin);
    //freopen("F:\\rush_out.txt", "w", stdout);
    int t;
    scanf("%d", &t);
    // Process the t test cases one after another.
    while (t)
    {
        --t;
        init();
        input_data();
        get_ans();
        output_ans();
    }
    return 0;
}```

【代码3】线段树，基础的单点递增

```#include
#include
#include
#define lson begin, m, rt << 1
#define rson m+1,end,rt<<1|1

using namespace std;

const int MAXN = 101000;
// Query threshold used when a[x] == 0 (then every ancestor forms a weak pair), so it
// only has to be >= every a_i (<= 1e9). The previous literal 10000000000000000000 (1e19)
// does not fit in long long (max ~9.22e18); converting it is implementation-defined,
// and GCC/MSVC yield 1e19 - 2^64 < 0. That negative INF sorted to the front of seq,
// so zero-valued nodes counted no ancestors. LLONG_MAX is the natural sentinel.
const long long INF = 0x7fffffffffffffffLL;

int  a[MAXN], father[MAXN], n, lena, sum[MAXN * 2 * 4]; // segment tree needs 4x the number of leaves
long long seq[MAXN * 2], ans = 0, k;
// Child lists. The element type was missing (`vector  son[MAXN];` does not compile).
vector<int> son[MAXN];

// Zeroes every node of the segment tree covering [begin, end] rooted at rt.
void build(int begin, int end, int rt)
{
    sum[rt] = 0;
    if (begin != end)
    {
        int m = (begin + end) >> 1;   // must be named `m`: the lson/rson macros use it
        build(lson);
        build(rson);
    }
}

// Resets per-test-case state: the answer, parent links and child lists.
// The segment tree itself is re-zeroed by build() in input_data().
void init()
{
    memset(father, 0, sizeof father);
    for (int node = 100010; node >= 1; --node)
        son[node].clear();
    ans = 0;
}

// Reads n, k and the node values, builds the sorted coordinate array
// (seq[1..n] = a[i], seq[n+1..2n] = k / a[i] or INF when a[i] == 0),
// then builds an all-zero segment tree over positions 1..2n.
void input_data()
{
    scanf("%d%I64d", &n, &k);
    for (int i = 1; i <= n; ++i)
    {
        scanf("%d", &a[i]);
        seq[i] = a[i];
        seq[n + i] = (a[i] == 0) ? INF : k / a[i];   // avoid dividing by zero
    }
    lena = 2 * n;
    sort(seq + 1, seq + 1 + lena);
    build(1, lena, 1);
}

// Range sum over [l, r] within the node rt that covers [begin, end].
long long query(int l, int r, int begin, int end, int rt)
{
    if (begin >= l && r >= end)     // node fully inside the query range
        return sum[rt];

    int m = (begin + end) >> 1;     // must be named `m`: the lson/rson macros use it
    const long long from_left = (l <= m) ? query(l, r, lson) : 0;
    const long long from_right = (r > m) ? query(l, r, rson) : 0;
    return from_left + from_right;
}

// Recomputes node rt as the sum of its children (rt*2 and rt*2+1).
void push_up(int rt)
{
    const int left_child = rt << 1;
    sum[rt] = sum[left_child] + sum[left_child | 1];
}

// Point update: adds num to leaf pos, then refreshes the sums on the way back up.
void updata(int pos, int num, int begin, int end, int rt)
{
    if (begin != end)
    {
        int m = (begin + end) >> 1;   // must be named `m`: the lson/rson macros use it
        if (pos > m)
            updata(pos, num, rson);
        else
            updata(pos, num, lson);
        push_up(rt);
        return;
    }
    sum[rt] += num;
}

// Adds to ans the number of ancestors u of x with a[u] * a[x] <= k, then recurses.
// The segment tree holds, per discretized position, the count of ancestor values.
void dfs(int x)
{
    const long long bound = (a[x] == 0) ? INF : k / a[x];   // a[x] == 0: all ancestors qualify
    const int upto = lower_bound(seq + 1, seq + 1 + 2 * n, bound) - seq;
    ans += query(1, upto, 1, lena, 1);

    const int own_pos = lower_bound(seq + 1, seq + 1 + 2 * n, a[x]) - seq;
    updata(own_pos, 1, 1, lena, 1);
    for (size_t i = 0; i != son[x].size(); ++i)
        dfs(son[x][i]);
    updata(own_pos, -1, 1, lena, 1);
}

// Reads the n - 1 parent -> child edges, locates the root and runs the DFS.
void get_ans()
{
    for (int edge = 1; edge < n; ++edge)
    {
        int u, v;
        scanf("%d %d", &u, &v);
        son[u].push_back(v);
        father[v] = u;
    }

    // The root is the node without a parent; scanning backwards and stopping at
    // the first hit picks the same node as a forward scan that keeps the last one.
    int root = 1;
    for (int i = n; i >= 1; --i)
    {
        if (father[i] == 0)
        {
            root = i;
            break;
        }
    }
    dfs(root);
}

// Prints the weak-pair count of the current test case on its own line.
void output_ans()
{
    fprintf(stdout, "%I64d\n", ans);
}

int main()
{
    //freopen("F:\\rush.txt", "r", stdin);
    //freopen("F:\\rush_out.txt", "w", stdout);
    int t;
    scanf("%d", &t);
    // Process the t test cases one after another.
    while (t)
    {
        --t;
        init();
        input_data();
        get_ans();
        output_ans();
    }
    return 0;
}```