fork download
  1. #include <bits/stdc++.h>
  using namespace std;

  // ---- Type shorthands -------------------------------------------------
  // Real type aliases instead of #define: they obey scope and type rules
  // (e.g. `ll(x)` works, and they cannot silently rewrite unrelated tokens).
  using ll = long long int;
  using ld = long double;
  using pll = pair<ll, ll>;
  using vll = vector<ll>;
  using vii = vector<int>;

  // ---- Syntactic shorthands --------------------------------------------
  // These must remain macros: they are spliced into expressions or used as
  // member names (e.g. `adj[x].pb(y)`, `p.ff`), which aliases cannot do.
  #define mk make_pair
  #define pb push_back
  #define all(x) x.begin(), x.end()
  #define allr(x) x.rbegin(), x.rend()
  #define dpp(arr, val) memset(arr, val, sizeof(arr))
  #define isON(n, k) (((n) >> (k)) & 1)
  #define f(i, a, b) for (ll i = a; i < b; i++)
  #define fb(i, a, b) for (ll i = a; i > b; i--)
  #define ff first
  #define ss second

  // ---- Constants ---------------------------------------------------------
  constexpr ll INF = static_cast<ll>(1e18);
  constexpr double PI = 3.141592653589793238;
  // int dx[8] = { 1, 0, -1, 0, -1, 1, -1, 1 };
  // int dy[8] = { 0, 1, 0, -1, -1, 1, 1, -1 };
  const int N = 2e5 + 5;
  const ll MOD = 1e9 + 7;

  // ---- Global state shared by dfs / lca / preprocess / solve ------------
  ll t = 1;                 // number of test cases (read in main)
  int n, l;                 // node count; highest binary-lifting level index
  vector<vector<ll>> adj;   // adjacency lists, indexed by node id
  vll a;                    // value stored on each node

  vector<int> depth;        // depth[v]: distance from the root (root = 0)
  vector<bitset<251>> ans;  // ans[v]: set of values occurring in v's subtree
  vector<vector<int>> up;   // up[v][j]: 2^j-th ancestor of v (root maps to itself)
  32.  
  33. void dfs(int v, int p)
  34. {
  35. depth[v] = (v == p) ? 0 : depth[p] + 1;
  36. up[v][0] = p;
  37.  
  38. for (int i = 1; i <= l; ++i)
  39. up[v][i] = up[up[v][i - 1]][i - 1];
  40.  
  41. for (int u : adj[v])
  42. {
  43. if (u != p){
  44. dfs(u, v);
  45. ans[v] |= ans[u];
  46. }
  47. }
  48. }
  49.  
  50. int lca(int u, int v)
  51. {
  52. if (depth[u] < depth[v])
  53. swap(u, v);
  54. int k = depth[u] - depth[v];
  55. for (int j = 0; j < l; j++)
  56. {
  57. if (k & (1LL << j))
  58. {
  59. u = up[u][j];
  60. }
  61. }
  62. if (u == v)
  63. return u;
  64. for (int j = l; j >= 0; j--)
  65. {
  66. if (up[u][j] != up[v][j])
  67. {
  68. u = up[u][j];
  69. v = up[v][j];
  70. }
  71. }
  72. return up[u][0];
  73. }
  74.  
  75. void preprocess(int root)
  76. {
  77. depth.resize(n + 1);
  78. l = ceil(log2(n)) + 1;
  79. up.assign(n + 1, vector<int>(l + 1));
  80. dfs(root, root);
  81. }
  82.  
  83. void solve()
  84. {
  85. ll q, root;
  86. cin >> n >> q >> root;
  87. ans = vector<bitset<251>>(n + 1);
  88. adj = vector<vll>(n + 1);
  89. a = vll(n + 1);
  90.  
  91. f(i, 0, n) {
  92. cin >> a[i];
  93. ans[i].reset();
  94. ans[i].set(a[i]);
  95. }
  96. f(i, 0, n - 1)
  97. {
  98. ll x, y;
  99. cin >> x >> y;
  100. adj[x].pb(y);
  101. adj[y].pb(x);
  102. }
  103. preprocess(root);
  104. while (q--)
  105. {
  106. ll x, y;
  107. cin >> x >> y;
  108. x = lca(x, y);
  109. cout << ans[x].count() << endl;
  110. }
  111.  
  112. return;
  113. }
  114.  
  115. int main()
  116. {
  117. ios_base::sync_with_stdio(0);
  118. cin.tie(0);
  119. cout.tie(0);
  120.  
  121. cin >> t;
  122. while (t--)
  123. {
  124. solve();
  125. }
  126.  
  127. return 0;
  128. }
Success #stdin #stdout 0.01s 5300KB
stdin
1
10 7 1
1 2 1 4 5 6 6 8 9 9
0 9
9 3
3 4
4 6
4 5
4 8
1 3
1 2
2 7
4 4
8 6
0 6
7 0
7 2
0 0
2 3
stdout
3
3
5
7
2
1
7