fork download
  1. #include <bits/stdc++.h>
  2. #define fi first
  3. #define se second
  4. #define all(v) v.begin() , v.end()
  5. #define sz(v) int(v.size())
  6. #define unq(v) sort(all(v)); v.resize(unique(all(v)) - v.begin());
  7. using namespace std;
  8.  
  9. typedef long long ll;
  10. typedef pair<int , int> ii;
  11. typedef pair<long long , int> lli;
  12.  
  13. const int maxN = int(2e5)+7;
  14. const int LOG = 20;
  15. const int maxK = 27;
  16. const int mod = 998244353;
  17.  
  18. int add(int x , int y){
  19. x += y;
  20. if (x >= mod) x -= mod;
  21. return x;
  22. }
  23.  
  24. void self_add(int &x , int y){
  25. x = add(x , y);
  26. }
  27.  
  28. int sub(int x , int y){
  29. x -= y;
  30. if (x < 0) x += mod;
  31. return x;
  32. }
  33.  
  34. void self_sub(int &x , int y){
  35. x = sub(x , y);
  36. }
  37.  
  38. int mul(int x , int y){
  39. return (1ll * x * y) % mod;
  40. }
  41.  
  42. int n , k , q , a[maxN] , tmp_a[maxN] , c[maxK];
  43. vector<int> g[maxN];
  44.  
  45. struct lca{
  46. int h[maxN] , up[maxN][LOG + 1] , sz[maxN] , cost[maxN];
  47. int in[maxN] , out[maxN] , DfsTime = 0;
  48.  
  49. void dfs(int u , int par){
  50. for (int i = 1 ; i <= LOG ; i++) up[u][i] = up[up[u][i - 1]][i - 1];
  51. sz[u] = 1;
  52. in[u] = ++DfsTime;
  53. for (int v : g[u]){
  54. if (v == par) continue;
  55. up[v][0] = u;
  56. h[v] = h[u] + 1;
  57. dfs(v , u);
  58. sz[u] += sz[v];
  59. }
  60. out[u] = DfsTime;
  61. }
  62.  
  63. int k_th(int u , int k){
  64. for (int i = LOG ; i >= 0 ; i--){
  65. if ((k>>i)&1) u = up[u][i];
  66. }
  67. return u;
  68. }
  69.  
  70. bool inside(int u , int v){
  71. return in[u] <= in[v] && out[v] <= out[u];
  72. }
  73.  
  74. int getNum(int u , int v){
  75. if (inside(u , v)){
  76. return mul(n - sz[k_th(v , h[v] - h[u] - 1)] , sz[v]);
  77. }
  78. if (inside(v , u)){
  79. return mul(n - sz[k_th(u , h[u] - h[v] - 1)] , sz[u]);
  80. }
  81. return mul(sz[u] , sz[v]);
  82. }
  83.  
  84. int getSize(int u , int v){
  85. if (u == v) return n;
  86. if (inside(u , v)){
  87. return sz[v];
  88. }
  89. if (inside(v , u)){
  90. return n - sz[k_th(u , h[u] - h[v] - 1)];
  91. }
  92. return sz[v];
  93. }
  94.  
  95. void build_cost(){
  96. for (int i = 1 ; i <= n ; i++){
  97. int pre = n - sz[i];
  98. cost[i] = n - 1;
  99. for (int j : g[i]){
  100. if (j != up[i][0]){
  101. self_add(cost[i] , mul(pre , sz[j]));
  102. pre += sz[j];
  103. }
  104. }
  105. }
  106. }
  107. } sigma;
  108.  
  109. int sz[maxN] , h[maxN] , p[maxN] , nxt_sz[LOG + 1][maxN] , num[LOG + 1][maxN];
  110. bool del[maxN];
  111.  
  112. void dfs_size(int u , int par){
  113. sz[u] = 1;
  114. for (int v : g[u]){
  115. if (v != par && del[v] == 0){
  116. dfs_size(v , u);
  117. sz[u] += sz[v];
  118. }
  119. }
  120. }
  121.  
  122. int dfs_find(int u , int par , int half){
  123. for (int v : g[u]){
  124. if (v != par && del[v] == 0 && sz[v] > half){
  125. return dfs_find(v , u , half);
  126. }
  127. }
  128. return u;
  129. }
  130.  
  131. void dfs_prepare(int x , int u , int par , int d){
  132. nxt_sz[d][u] = sigma.getSize(x , u);
  133. num[d][u] = sigma.getNum(x , u);
  134. for (int v : g[u]){
  135. if (v != par && del[v] == 0){
  136. dfs_prepare(x , v , u , d);
  137. }
  138. }
  139. }
  140.  
  141. int id[maxN] , numID = 0;
  142.  
  143. void dfs_build(int u , int par , int d){
  144. dfs_size(u , 0);
  145. u = dfs_find(u , 0 , sz[u] / 2);
  146. del[u] = 1;
  147. h[u] = d;
  148. p[u] = par;
  149. id[u] = ++numID;
  150. dfs_prepare(u , u , par , d);
  151. for (int v : g[u]){
  152. if (del[v] == 0){
  153. dfs_build(v , u , d + 1);
  154. }
  155. }
  156. }
  157.  
  158. int ans = 0 , sum[maxK][maxN] , cur[maxK][maxN] , sum_sub[maxK][maxN];
  159.  
  160. void update(int u , int t){
  161. if (t == +1){
  162. self_add(ans , cur[a[u]][u]);
  163. self_add(ans , mul(sigma.cost[u] , c[a[u]]));
  164. }
  165. else{
  166. self_sub(ans , cur[a[u]][u]);
  167. self_sub(ans , mul(sigma.cost[u] , c[a[u]]));
  168. }
  169. for (int x = p[u] , y = u ; x != 0 ; x = p[x] , y = p[y]){
  170. if (t == +1){
  171. self_add(sum[a[u]][x] , nxt_sz[h[x]][u]);
  172. self_add(sum_sub[a[u]][id[y]] , nxt_sz[h[x]][u]);
  173. self_add(ans , mul(mul(2 , c[a[u]]) , mul(sub(sum[a[u]][x] , sum_sub[a[u]][id[y]]) , nxt_sz[h[x]][u])));
  174. self_add(cur[a[u]][x] , mul(mul(2 , c[a[u]]) , num[h[x]][u]));
  175. if (a[x] == a[u]){
  176. self_add(ans , mul(mul(2 , c[a[u]]) , num[h[x]][u]));
  177. }
  178. }
  179. else{
  180. self_sub(sum[a[u]][x] , nxt_sz[h[x]][u]);
  181. self_sub(sum_sub[a[u]][id[y]] , nxt_sz[h[x]][u]);
  182. self_sub(ans , mul(mul(2 , c[a[u]]) , mul(sub(sum[a[u]][x] , sum_sub[a[u]][id[y]]) , nxt_sz[h[x]][u])));
  183. self_sub(cur[a[u]][x] , mul(mul(2 , c[a[u]]) , num[h[x]][u]));
  184. if (a[x] == a[u]){
  185. self_sub(ans , mul(mul(2 , c[a[u]]) , num[h[x]][u]));
  186. }
  187. }
  188. }
  189. }
  190.  
  191. void solve(){
  192. cin >> n >> k >> q;
  193. for (int i = 1 ; i <= n ; i++){
  194. cin >> tmp_a[i];
  195. }
  196. for (int i = 1 ; i <= k ; i++) cin >> c[i];
  197. for (int i = 1 ; i < n ; i++){
  198. int u , v;
  199. cin >> u >> v;
  200. g[u].push_back(v);
  201. g[v].push_back(u);
  202. }
  203. sigma.dfs(1 , 0);
  204. sigma.build_cost();
  205. dfs_build(1 , 0 , 0);
  206. //return;
  207. for (int i = 1 ; i <= n ; i++){
  208. a[i] = tmp_a[i];
  209. update(i , +1);
  210. }
  211. cout << ans << "\n";
  212. while (q--){
  213. int u , x;
  214. cin >> u >> x;
  215. update(u , -1);
  216. a[u] = x;
  217. update(u , +1);
  218. cout << ans << "\n";
  219. }
  220. }
  221.  
  222. #define name "fbuy"
  223.  
  224. int main(){
  225. ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
  226. if (fopen(name".INP" , "r")){
  227. freopen(name".INP" , "r" , stdin);
  228. freopen(name".OUT" , "w" , stdout);
  229. }
  230. int t = 1; //cin >> t;
  231. while (t--) solve();
  232. return 0;
  233. }
  234.  
Success #stdin #stdout 0.01s 17908KB
stdin
Standard input is empty
stdout
0