fork download
  1. /* package whatever; // don't place package name! */
  2.  
  3. import java.util.*;
  4. import java.lang.*;
  5. import java.io.*;
  6.  
  7. /* Name of the class has to be "Main" only if the class is public. */
  class Main
  {
      /**
       * Counts {@code node} plus every vertex reachable from it through vertices not yet
       * marked in {@code vis}, marking all of them as visited.
       *
       * <p>Implemented iteratively with an explicit stack. A long chain (e.g. a path-shaped
       * tree with 10^5 vertices) therefore cannot overflow the call stack, as a recursive
       * traversal would.
       *
       * @param node start vertex; it is always counted, even if it was already marked
       * @param adj  adjacency lists indexed by vertex
       * @param vis  visit marks ({@code 0} = unvisited, {@code 1} = visited), updated in place
       * @return {@code 1} for {@code node} plus the number of previously unvisited vertices
       *     reached from it
       */
      public static int dfs(int node, ArrayList<Integer>[] adj, int[] vis) {
          ArrayDeque<Integer> stack = new ArrayDeque<>();
          vis[node] = 1;
          stack.push(node);
          int cnt = 1;
          while (!stack.isEmpty()) {
              int cur = stack.pop();
              for (int child : adj[cur]) {
                  if (vis[child] == 0) {
                      // Mark on push so each vertex is counted and pushed at most once.
                      vis[child] = 1;
                      cnt++;
                      stack.push(child);
                  }
              }
          }
          return cnt;
      }

      /**
       * Builds the graph that keeps only edges whose endpoints share a colour, then sums the
       * sizes of its connected components.
       *
       * <p>NOTE(review): every vertex belongs to exactly one component, so this sum always
       * equals {@code n}. If the intent was to count same-colour vertex pairs, each component
       * of size {@code k} should contribute {@code k * (k - 1) / 2} instead. TODO confirm
       * against the problem statement.
       *
       * @param edges pairs of 0-based vertex indices
       * @param col   colour of each vertex, indexed by vertex
       * @param n     number of vertices
       * @return the sum of component sizes (for valid input, always {@code n})
       */
      public static int findRoute(int[][] edges, int[] col, int n) {
          @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
          ArrayList<Integer>[] adj = new ArrayList[n];
          for (int i = 0; i < n; i++) {
              adj[i] = new ArrayList<>();
          }
          for (int[] edge : edges) {
              int x = edge[0];
              int y = edge[1];
              // Only edges joining two vertices of the same colour are kept.
              if (col[x] == col[y]) {
                  adj[x].add(y);
                  adj[y].add(x);
              }
          }

          int cnt = 0;
          int[] vis = new int[n];
          for (int i = 0; i < n; i++) {
              if (vis[i] == 0) {
                  cnt += dfs(i, adj, vis);
              }
          }
          return cnt;
      }

      /**
       * Reads input from standard input and prints the result of {@link #findRoute}.
       * The input is {@code n}, then {@code n - 1} edges as 1-based vertex pairs, then
       * {@code n} colours.
       */
      public static void main(String[] args) {
          try (Scanner sc = new Scanner(System.in)) {
              int n = sc.nextInt();
              // Math.max guards n == 0, which would otherwise throw NegativeArraySizeException.
              int[][] edges = new int[Math.max(n - 1, 0)][2];
              for (int i = 0; i < edges.length; i++) {
                  // Input vertices are 1-based; store them as 0-based indices.
                  edges[i][0] = sc.nextInt() - 1;
                  edges[i][1] = sc.nextInt() - 1;
              }

              int[] col = new int[n];
              for (int i = 0; i < n; i++) {
                  col[i] = sc.nextInt();
              }
              System.out.println(findRoute(edges, col, n));
          }
      }
  }
Success #stdin #stdout 0.18s 56660KB
stdin
6
1 2
2 3
2 4
1 5
5 6
0 1 1 1 0 0
stdout
6