fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  /// Builds an undirected adjacency list with slots for vertex labels 0..n.
  ///
  /// Slot 0 is allocated, but the sample input in main uses 1-based labels,
  /// so it normally stays empty.
  ///
  /// @param n      largest vertex label; the result has n + 1 slots.
  /// @param edges  undirected edges {u, v}; each is stored in both directions.
  /// @return       g where g[x] lists x's neighbours in edge-insertion order.
  /// @throws std::out_of_range if any edge endpoint lies outside [0, n].
  std::vector<std::vector<int>> buildGraph(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> g(n + 1);
    for (const auto& [u, v] : edges) {
      // Validate before touching g: indexing with an unchecked label is UB.
      if (u < 0 || u > n || v < 0 || v > n) {
        throw std::out_of_range("buildGraph: edge endpoint outside [0, n]");
      }
      g[u].push_back(v);
      g[v].push_back(u);
    }
    return g;
  }
  12.  
  /// Breadth-first search from `root`, recording the parent of every reached vertex.
  ///
  /// @param root  start vertex; must be a valid index into g.
  /// @param g     adjacency list (for example from buildGraph).
  /// @return      par with par.size() == g.size(), where:
  ///              - par[root] == 0;
  ///              - par[v] is the vertex from which v was first discovered;
  ///              - par[v] == -1 for every vertex not reachable from root.
  ///              NOTE: the 0 sentinel is indistinguishable from a real vertex 0,
  ///              so this presumably assumes 1-based labels.
  /// @throws std::out_of_range if root is not a valid index into g.
  std::vector<int> getParents(int root, const std::vector<std::vector<int>>& g) {
    // Size from g directly; the old `int n = g.size() - 1` wrapped for an empty g.
    std::vector<int> par(g.size(), -1);
    // .at() rejects an out-of-range (or negative) root instead of writing out of bounds.
    par.at(root) = 0;

    std::queue<int> q;
    q.push(root);
    while (!q.empty()) {
      const int u = q.front();
      q.pop();
      for (const int v : g[u]) {
        if (par[v] == -1) {  // first discovery of v
          par[v] = u;
          q.push(v);
        }
      }
    }
    return par;
  }
  32.  
  /// Collects every vertex on the parent-chain paths from each target up to `root`.
  ///
  /// When the underlying graph is a tree, these vertices form the smallest
  /// connected subtree containing root and all targets.
  ///
  /// @param root     vertex the parent array was built from; always included.
  /// @param targets  vertices to connect to root. A target equal to 0 is ignored
  ///                 (0 is the "no parent" sentinel in par).
  /// @param par      parent array in the format produced by getParents
  ///                 (par[root] == 0, par[v] == -1 for unreachable v).
  /// @return         the set of collected vertices.
  /// @throws std::invalid_argument if a target, or a vertex met while climbing,
  ///         lies outside par or is unreachable from root.
  std::set<int> getSteinerNodes(int root, const std::vector<int>& targets, const std::vector<int>& par) {
    std::set<int> nodes{root};
    const auto size = static_cast<int>(par.size());

    for (int t : targets) {
      // Climb toward the root. Once we hit an already-collected vertex (the
      // root itself, or one from an earlier target's path), all of its
      // ancestors are already collected too. Stopping there means each vertex
      // is inserted once, not once per target whose path passes through it.
      while (t != 0 && nodes.count(t) == 0) {
        // Previously an unreachable target stepped to par[-1] (out of bounds).
        if (t < 0 || t >= size || par[t] == -1) {
          throw std::invalid_argument("getSteinerNodes: target not reachable from root");
        }
        nodes.insert(t);
        t = par[t];
      }
    }
    return nodes;
  }
  46.  
  47. int main() {
  48. int n = 8;
  49. vector<int> targets = {5, 6};
  50. vector<pair<int, int>> edges = {
  51. {1, 2}, {2, 5}, {2, 3}, {2, 6},
  52. {1, 4}, {4, 7}, {7, 8}
  53. };
  54.  
  55. auto g = buildGraph(n, edges);
  56. auto par = getParents(1, g);
  57. auto steinerNodes = getSteinerNodes(1, targets, par);
  58.  
  59. // Each edge traversed twice (down and back)
  60. int result = 2 * (steinerNodes.size() - 1);
  61. cout << result << endl;
  62.  
  63. return 0;
  64. }
Success #stdin #stdout 0.01s 5292KB
stdin
Standard input is empty
stdout
6