fork download
  1. from itertools import combinations
  2. from io import BytesIO # 在Python 2.7中使用BytesIO代替StringIO
  3. import csv
  4.  
  5. class Apriori:
  6. def __init__(self, min_support=0.01, min_confidence=0.5):
  7. self.min_support = min_support
  8. self.min_confidence = min_confidence
  9. self.frequent_itemsets = None
  10. self.association_rules = None
  11.  
  12. def fit(self, transactions):
  13. """执行Apriori算法"""
  14. items = self._get_unique_items(transactions)
  15. frequent_itemsets_1 = self._get_frequent_itemsets(transactions, items, 1)
  16. frequent_itemsets = [frequent_itemsets_1]
  17.  
  18. k = 2
  19. while True:
  20. candidates = self._generate_candidates(frequent_itemsets[k-2], k)
  21. if not candidates:
  22. break
  23.  
  24. freq_itemsets_k = self._get_frequent_itemsets(transactions, candidates, k)
  25. if not freq_itemsets_k:
  26. break
  27.  
  28. frequent_itemsets.append(freq_itemsets_k)
  29. k += 1
  30.  
  31. self.frequent_itemsets = frequent_itemsets
  32. return self
  33.  
  34. def generate_rules(self):
  35. """生成关联规则"""
  36. if self.frequent_itemsets is None:
  37. raise ValueError("请先执行fit方法")
  38.  
  39. rules = []
  40. for itemset_level in self.frequent_itemsets[1:]:
  41. for itemset, support in itemset_level.items():
  42. itemset_list = list(itemset)
  43. for i in range(1, len(itemset_list)):
  44. for antecedent in combinations(itemset_list, i):
  45. antecedent = frozenset(antecedent)
  46. consequent = itemset - antecedent
  47. confidence = support / self.frequent_itemsets[len(antecedent)-1][antecedent]
  48. if confidence >= self.min_confidence:
  49. rules.append((antecedent, consequent, support, confidence))
  50.  
  51. self.association_rules = rules
  52. return rules
  53.  
  54. def _get_unique_items(self, transactions):
  55. """获取所有唯一的项"""
  56. items = set()
  57. for transaction in transactions:
  58. for item in transaction:
  59. items.add(item)
  60. return [frozenset([item]) for item in items]
  61.  
  62. def _get_frequent_itemsets(self, transactions, candidates, k):
  63. """计算候选项集的支持度并筛选频繁项集"""
  64. itemset_counts = {}
  65. for transaction in transactions:
  66. transaction_set = set(transaction)
  67. for candidate in candidates:
  68. if candidate.issubset(transaction_set):
  69. itemset_counts[candidate] = itemset_counts.get(candidate, 0) + 1
  70.  
  71. num_transactions = len(transactions)
  72. frequent_itemsets = {}
  73. for itemset, count in itemset_counts.items():
  74. support = float(count) / num_transactions
  75. if support >= self.min_support:
  76. frequent_itemsets[itemset] = support
  77.  
  78. return frequent_itemsets
  79.  
  80. def _generate_candidates(self, frequent_itemsets_prev, k):
  81. """生成候选k项集"""
  82. candidates = set()
  83. items = list(frequent_itemsets_prev.keys())
  84.  
  85. for i in range(len(items)):
  86. for j in range(i+1, len(items)):
  87. union_set = items[i].union(items[j])
  88. if len(union_set) == k:
  89. subsets = combinations(union_set, k-1)
  90. if all(frozenset(subset) in frequent_itemsets_prev for subset in subsets):
  91. candidates.add(union_set)
  92.  
  93. return candidates
  94.  
  95. def main():
  96. # 模拟CSV数据(使用普通字符串)
  97. csv_data = """milk,bread,butter
  98. bread,eggs
  99. milk,bread,eggs
  100. milk,eggs
  101. bread,butter
  102. milk,bread,butter,eggs
  103. milk,bread
  104. bread,eggs,butter
  105. milk,eggs,butter
  106. bread,butter"""
  107.  
  108. # Python 2.7兼容方案:使用BytesIO + 手动解码
  109. transactions = []
  110. for line in csv_data.split('\n'):
  111. transactions.append(line.strip().split(','))
  112.  
  113. # 初始化Apriori算法
  114. apriori = Apriori(min_support=0.3, min_confidence=0.7)
  115.  
  116. # 执行算法
  117. apriori.fit(transactions)
  118. rules = apriori.generate_rules()
  119.  
  120. # 打印结果(Python 2.7使用print语句)
  121. print "频繁项集:"
  122. for level, itemsets in enumerate(apriori.frequent_itemsets, 1):
  123. print "\n%d-项集 (数量: %d)" % (level, len(itemsets))
  124. for itemset, support in itemsets.items():
  125. print "%s: 支持度=%.4f" % (tuple(itemset), support)
  126.  
  127. print "\n关联规则:"
  128. for rule in rules:
  129. antecedent, consequent, support, confidence = rule
  130. print "%s => %s: 支持度=%.4f, 置信度=%.4f" % (
  131. tuple(antecedent), tuple(consequent), support, confidence)
  132.  
  133. if __name__ == "__main__":
  134. main()
Success #stdin #stdout 0.02s 7724KB
stdin
Standard input is empty
stdout
频繁项集:

1-项集 (数量: 4)
('bread',): 支持度=0.8000
('milk',): 支持度=0.6000
('eggs',): 支持度=0.6000
('butter',): 支持度=0.6000

2-项集 (数量: 6)
('milk', 'bread'): 支持度=0.4000
('butter', 'eggs'): 支持度=0.3000
('butter', 'bread'): 支持度=0.5000
('butter', 'milk'): 支持度=0.3000
('eggs', 'milk'): 支持度=0.4000
('eggs', 'bread'): 支持度=0.4000

关联规则:
('butter',) => ('bread',): 支持度=0.5000, 置信度=0.8333