fork(1) download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  // Prime modulus; every probability is represented as a residue modulo MOD.
  constexpr long long MOD = 998244353;

  // Computes base^exp mod MOD by binary exponentiation.
  // base may be any long long: it is first reduced into [0, MOD), so negative
  // inputs give proper residues and every product below stays under MOD^2,
  // which fits in a signed 64-bit integer. A non-positive exp yields 1.
  long long modExp(long long base, long long exp) {
    base %= MOD;
    if (base < 0) base += MOD;  // C++ '%' keeps the sign of the dividend
    long long res = 1;
    while (exp > 0) {
      if (exp & 1) res = res * base % MOD;
      base = base * base % MOD;
      exp >>= 1;
    }
    return res;
  }

  // Modular inverse of x via Fermat's little theorem (MOD is prime):
  // x^(MOD-2) * x == 1 (mod MOD) whenever x is not a multiple of MOD.
  // If x == 0 (mod MOD) no inverse exists and this returns 0; callers must
  // not rely on it in that case.
  long long modInv(long long x) {
    return modExp(x, MOD - 2);
  }
  20.  
  21. // In our forest of probabilities, each vertex i has
  22. // F[i] = p[i]/q[i] (fall probability) and R[i] = 1 - F[i] (remain probability).
  23. // A vertex becomes a leaf if it remains and exactly one neighbor remains.
  24. // We then correct for dependencies when two vertices are “close.”
  25.  
  26. int main(){
  27. ios::sync_with_stdio(false);
  28. cin.tie(nullptr);
  29.  
  30. int t;
  31. cin >> t;
  32. while(t--){
  33. int n;
  34. cin >> n;
  35. vector<long long> p(n), q(n), F(n), R(n), L(n, 0);
  36. vector<vector<int>> graph(n);
  37.  
  38. for (int i = 0; i < n; i++){
  39. cin >> p[i] >> q[i];
  40. long long invq = modInv(q[i]);
  41. F[i] = (p[i] % MOD * invq) % MOD; // falling probability
  42. R[i] = (1 - F[i] + MOD) % MOD; // remaining probability
  43. }
  44.  
  45. for (int i = 0; i < n - 1; i++){
  46. int u, v;
  47. cin >> u >> v;
  48. u--; v--;
  49. graph[u].push_back(v);
  50. graph[v].push_back(u);
  51. }
  52.  
  53. // Precompute for each vertex the product over its neighbors of F[neighbor]
  54. vector<long long> prod(n, 1);
  55. for (int i = 0; i < n; i++){
  56. for (int nb : graph[i]){
  57. prod[i] = (prod[i] * F[nb]) % MOD;
  58. }
  59. }
  60.  
  61. // Compute L[i]: probability vertex i becomes a leaf.
  62. // If i has no neighbors, it can never be a leaf (by the problem’s definition).
  63. for (int i = 0; i < n; i++){
  64. if(graph[i].empty()){
  65. L[i] = 0;
  66. } else {
  67. long long sumNeighbor = 0;
  68. // For each neighbor j, the chance that j is the unique surviving neighbor:
  69. // Multiply by R[j] and “cancel” F[j] from the full product over neighbors.
  70. for (int nb : graph[i]){
  71. long long term = (R[nb] * prod[i]) % MOD;
  72. term = (term * modInv(F[nb])) % MOD;
  73. sumNeighbor = (sumNeighbor + term) % MOD;
  74. }
  75. L[i] = (R[i] * sumNeighbor) % MOD;
  76. }
  77. }
  78.  
  79. // S0 is the "naively independent" contribution:
  80. // S0 = 1/2 * [ (sum_i L[i])^2 - sum_i L[i]^2 ]
  81. long long sumL = 0, sumL2 = 0;
  82. for (int i = 0; i < n; i++){
  83. sumL = (sumL + L[i]) % MOD;
  84. sumL2 = (sumL2 + (L[i] * L[i]) % MOD) % MOD;
  85. }
  86. long long inv2 = modInv(2);
  87. long long S0 = (((sumL * sumL) % MOD - sumL2 + MOD) % MOD * inv2) % MOD;
  88.  
  89. // Correction for adjacent vertices:
  90. // For an edge (u,v), both become leaves only if they “choose each other.”
  91. long long sum_adj = 0;
  92. for (int u = 0; u < n; u++){
  93. for (int v : graph[u]){
  94. if(u < v){
  95. long long term = (R[u] * R[v]) % MOD;
  96. // In vertex u, the unique surviving neighbor must be v:
  97. long long part_u = (prod[u] * modInv(F[v])) % MOD;
  98. // And vice versa for vertex v.
  99. long long part_v = (prod[v] * modInv(F[u])) % MOD;
  100. term = (term * ((part_u * part_v) % MOD)) % MOD;
  101. term = (term - (L[u] * L[v]) % MOD + MOD) % MOD;
  102. sum_adj = (sum_adj + term) % MOD;
  103. }
  104. }
  105. }
  106.  
  107. // Correction for pairs of vertices at distance 2 (they share a common neighbor u).
  108. // We sum over each vertex u as the common neighbor.
  109. long long sum_dist2 = 0;
  110. for (int u = 0; u < n; u++){
  111. int d = graph[u].size();
  112. if(d < 2) continue;
  113.  
  114. // We want to compute, over unordered pairs (i, j) among neighbors of u, the sum
  115. // of δ^{(2)}_{ij} defined by:
  116. // δ^{(2)}_{ij} = R[i]*R[j]*( R[u]*(A_i*A_j)
  117. // + ((L[i] - R[i]*R[u]*A_i) * (L[j] - R[j]*R[u]*A_j) * inv(F[u]) ) )
  118. // - L[i]*L[j],
  119. // where A_i = prod[i] * inv(F[u]).
  120. //
  121. // We can sum these pairs in O(d) per u by precomputing sums.
  122.  
  123. long long invF_u = modInv(F[u]);
  124. long long invF_u2 = (invF_u * invF_u) % MOD;
  125.  
  126. long long sum_RA = 0, sum_RA2 = 0; // for term1: using X_i = R[i]*prod[i]
  127. long long sum_B = 0, sum_B2 = 0; // for term2: using B[i] = R[i]*(L[i] - R[i]*R[u]*(prod[i]*invF_u))
  128. long long sum_Lu = 0, sum_Lu2 = 0; // for term3: using L[i]
  129.  
  130. for (int i : graph[u]) {
  131. long long X = (R[i] * prod[i]) % MOD; // note: A_i will be X * invF_u
  132. sum_RA = (sum_RA + X) % MOD;
  133. sum_RA2 = (sum_RA2 + (X * X) % MOD) % MOD;
  134.  
  135. sum_Lu = (sum_Lu + L[i]) % MOD;
  136. sum_Lu2 = (sum_Lu2 + (L[i] * L[i]) % MOD) % MOD;
  137.  
  138. long long A = (prod[i] * invF_u) % MOD;
  139. long long diff = (L[i] - (R[i] * R[u]) % MOD * A % MOD + MOD) % MOD;
  140. long long B = (R[i] * diff) % MOD;
  141. sum_B = (sum_B + B) % MOD;
  142. sum_B2 = (sum_B2 + (B * B) % MOD) % MOD;
  143. }
  144.  
  145. // Term1: contribution from R[u]*A_i*A_j over pairs
  146. long long term1 = (R[u] * invF_u2) % MOD;
  147. term1 = (term1 * (((sum_RA * sum_RA) % MOD - sum_RA2 + MOD) % MOD)) % MOD;
  148. term1 = (term1 * inv2) % MOD;
  149.  
  150. // Term2: contribution from the difference part, multiplied by inv(F[u])
  151. long long term2 = (invF_u * (((sum_B * sum_B) % MOD - sum_B2 + MOD) % MOD)) % MOD;
  152. term2 = (term2 * inv2) % MOD;
  153.  
  154. // Term3: subtract the “naive” product over L-values
  155. long long term3 = (((sum_Lu * sum_Lu) % MOD - sum_Lu2 + MOD) % MOD * inv2) % MOD;
  156.  
  157. long long cur_u = (term1 + term2 - term3 + MOD) % MOD;
  158. sum_dist2 = (sum_dist2 + cur_u) % MOD;
  159. }
  160.  
  161. // Final answer is S0 plus 1/2 times the sum of adjacent and distance–2 corrections.
  162. long long finalAns = (S0 + inv2 * ((sum_adj + sum_dist2) % MOD)) % MOD;
  163. cout << finalAns % MOD << "\n";
  164. }
  165. return 0;
  166. }
Success #stdin #stdout 0.01s 5292KB
stdin
5
1
1 2
3
1 2
1 2
1 2
1 2
2 3
3
1 3
1 5
1 3
1 2
2 3
1
998244351 998244352
6
10 17
7 13
6 11
2 10
10 19
5 13
4 3
3 6
1 4
3 5
3 2
stdout
0
717488129
535354750
0
752877708