#include <iostream>
#include <unordered_map>
using namespace std;
/// Node of a binary tree carrying an integer payload.
///
/// The node only stores the child pointers; it never allocates or frees
/// them, so whoever creates the children is responsible for their lifetime.
struct TreeNode {
    int val;          ///< Payload stored at this node.
    TreeNode *left;   ///< Left child, or nullptr when there is none.
    TreeNode *right;  ///< Right child, or nullptr when there is none.

    /// Builds a childless node holding `x`.
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};
int countIslands(TreeNode* node, unordered_map<TreeNode*, int>& dp) {
if (node == nullptr) {
return 0;
}
int leftIslands = countIslands(node->left, dp);
int rightIslands = countIslands(node->right, dp);
if (node->val == 0) {
dp[node] = leftIslands + rightIslands;
} else {
if (node->left != nullptr && node->left->val == 1 && node->right != nullptr && node->right->val == 1) {
dp[node] = leftIslands + rightIslands - 1; // Merging two islands
} else if (node->left != nullptr && node->left->val == 1) {
dp[node] = leftIslands + rightIslands; // Only left child is 1
} else if (node->right != nullptr && node->right->val == 1) {
dp[node] = leftIslands + rightIslands; // Only right child is 1
} else {
dp[node] = 1 + leftIslands + rightIslands;
}
}
return dp[node];
}
/**
 * Returns the number of islands in the whole tree rooted at `root`.
 *
 * Creates the per-node count table that countIslands fills in and discards
 * it afterwards; only the total for `root` is handed back.
 *
 * @param root Root of the tree; may be null, in which case 0 is returned.
 * @return Island count reported by countIslands for `root`.
 */
int findIslands(TreeNode* root) {
    std::unordered_map<TreeNode*, int> perNodeCounts;
    const int total = countIslands(root, perNodeCounts);
    return total;
}
int main() {
TreeNode* root = new TreeNode(0);
root->left = new TreeNode(1);
root->right = new TreeNode(1);
root->left->left = new TreeNode(0);
root->left->right = new TreeNode(1);
root->left->left->left = new TreeNode(1);
root->left->left->right = new TreeNode(1);
root->right->left = new TreeNode(1);
root->right->right = new TreeNode(0);
cout << "Number of islands: " << findIslands(root) << endl;
return 0;
}