跳到主要内容从零开始实现决策树——手撕 CART 算法(C++) | 极客日志C++AI算法
从零开始实现决策树——手撕 CART 算法(C++)
基于 C++20 标准从零实现 CART 决策树算法,包含分类树与回归树两部分。详细阐述了数据结构设计、训练过程(基尼指数与平方误差最小化)、预测方法及代价复杂度剪枝(CCP)的实现细节。同时列出了离散值处理、缺失值及多线程适配等实际应用中的注意事项。
赛博朋克6 浏览 本文的 C++ 代码基于 C++ 20 标准(不包含 C++ modules),对于之前的标准,可能需要做一些适配。
CART 分类树和回归树的内容各自在一个类中,分类树为 CartClassifier 类,回归树为 CartRegression 类。
数据结构设计
二叉树设计
struct BinTreeNode {
std::string threshold_str_;
double threshold_ = -1;
std::string feature_name_;
std::shared_ptr<BinTreeNode> left_ = nullptr;
std::shared_ptr<BinTreeNode> right_ = nullptr;
[[nodiscard]] std::shared_ptr<BinTreeNode> copy() const {
auto node = std::make_shared<BinTreeNode>();
node->threshold_ = threshold_;
node->threshold_str_ = threshold_str_;
node->feature_name_ = feature_name_;
if(left_) node->left_ = left_->copy();
if(right_) node->right_ = right_->copy();
return node;
}
};
copy 模块用于二叉树结点的深复制,包括复制本身及其所有的子结点。
结点信息设计
struct Info {
std::shared_ptr<BinTreeNode> tree_;
size_t num_leaf_ = 0;
double a = 0;
std::pair<bool, std::string> key_str_{};
std::pair<bool, double> key_{};
};
实际上结点信息可以直接存储到二叉树结点 BinTreeNode 中。分开是为了保证代码的语义清晰,易于理解。
分类树
训练
shared_ptr<BinTreeNode> CartClassifier::train {
feature_names_ = feature_names;
tree_ = (X, y);
tree_;
}
微信扫一扫,关注极客日志
微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
- Markdown转HTML
将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML转Markdown 互为补充。 在线工具,Markdown转HTML在线工具,online
(const vector<vector<string>>& X,
const vector<string>& y,
const vector<string>& feature_names)
create_tree
return
训练函数通过传递常量引用形参,防止训练集和属性集被篡改。如需要修改,可以在函数内部设置副本,针对副本进行修改。create_tree 是创建 CART 分类树的核心函数。
shared_ptr<BinTreeNode> CartClassifier::create_tree(const vector<vector<string>>& X,
const vector<string>& y) {
auto tree = make_shared<BinTreeNode>();
if (unordered_set(y.begin(), y.end()).size() == 1) {
tree->threshold_str_ = y.front();
return tree;
}
if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1) {
tree->threshold_str_ = majority_y(y);
return tree;
}
auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
const string best_feature_name = feature_names_[best_feature_index];
vector<vector<string>> sub_X1, sub_X2;
vector<string> sub_y1, sub_y2;
for (int i = 0; i < X.size(); i++)
if (X[i][best_feature_index] == best_split_point) {
sub_X1.emplace_back(X[i]);
sub_y1.emplace_back(y[i]);
} else {
sub_X2.emplace_back(X[i]);
sub_y2.emplace_back(y[i]);
}
tree->feature_name_ = best_feature_name;
tree->threshold_str_ = best_split_point;
tree->left_ = create_tree(sub_X1, sub_y1);
tree->right_ = create_tree(sub_X2, sub_y2);
return tree;
}
create_tree 函数是一个递归创建决策树的过程。首先判断三种递归中止条件:
X 中样本全部属于同一类别;
- 当前节点样本数小于
min_samples_split_;
- 属性集上的取值均相同
若满足终止条件,则选择 y 中最多的类别作为结果返回。若未满足终止条件,依次执行以下步骤:
- 根据基尼指数从属性值中选择最优分裂属性的最优切分点,具体过程如
choose_best_point_to_split 函数所示;
- 根据最优切分点对子树进行划分;
- 对于其子树再继续执行
create_tree 函数完成划分过程。
string CartClassifier::majority_y(const vector<string>& y) {
unordered_map<string, int> y_count;
for (const string& v : y) {
if (!y_count.contains(v)) y_count[v] = 0;
++y_count[v];
}
return ranges::max_element(y_count, [](const pair<string, int>& a, const pair<string, int>& b) {
return a.second < b.second;
})->first;
}
majority_y 用于计算节点中出现次数最多的类别。包含以下步骤:
- 初始化一个空映射;
- 遍历
y 并对其元素进行计数;
- 从映射中查找出现次数最多的类别。
pair<string, int> CartClassifier::choose_best_point_to_split(const vector<vector<string>>& X,
const vector<string>& y) {
string best_split_point;
int best_feature_index = -1;
double best_gini_index = numeric_limits<double>::infinity();
const size_t num_feature = X[0].size();
for (int i = 0; i < num_feature; i++)
{
unordered_set<string> split_points;
for (const vector<string>& x : X) split_points.emplace(x[i]);
for (const string& split_point : split_points)
{
vector<string> sub_y_left, sub_y_right;
for (int j = 0; j < X.size(); j++)
if (X[j][i] == split_point) sub_y_left.emplace_back(y[j]);
else sub_y_right.emplace_back(y[j]);
const double gini_impurity_left = cal_gini_impurity(sub_y_left);
const double gini_impurity_right = cal_gini_impurity(sub_y_right);
const double pro_left = static_cast<double>(sub_y_left.size()) / static_cast<double>(y.size()),
pro_right = static_cast<double>(sub_y_right.size()) / static_cast<double>(y.size());
if (const double gini_index = cal_gini_index(pro_left, pro_right, gini_impurity_left, gini_impurity_right);
best_gini_index > gini_index)
{
best_gini_index = gini_index;
best_feature_index = i;
best_split_point = split_point;
}
}
}
return {best_split_point, best_feature_index};
}
choose_best_point_to_split 是 CART 分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。该函数遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度 G(left) 和 G(right),以及左右子树中数据样本在总样本中占的比例 P(left) 和 P(right),并且将 G(left), G(right), P(left), P(right) 代入 cal_gini_index 函数中计算基尼指数。最后选出具有最小基尼指数的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。
double CartClassifier::cal_gini_impurity(const vector<string>& y) {
unordered_map<string, int> y_count;
for (const string& v : y) {
if (!y_count.contains(v)) y_count[v] = 0;
++y_count[v];
}
double gini_impurity = 1;
const auto num_samples = static_cast<double>(y.size());
for (const int& k : y_count | views::values) {
const double prob = k / num_samples;
gini_impurity -= prob * prob;
}
return gini_impurity;
}
cal_gini_impurity 用于计算基尼不纯度,包含以下步骤:
- 分析导入的数据集的最后一列(一般默认为数据类别),根据不同类别按出现次数统计到分类字典中;
- 遍历该字典,根据公式用 1 减去不同的类分布概率的平方和,得到最终的基尼不纯度。
double CartClassifier::cal_gini_index(const double pro_left, const double pro_right,
const double gini_impurity_left, const double gini_impurity_right) {
return pro_left * gini_impurity_left + pro_right * gini_impurity_right;
}
cal_gini_index 通过公式计算基尼指数。
预测
vector<string> CartClassifier::predict(const vector<vector<string>>& X) {
vector<string> y_preds;
for (const vector<string>& x : X) y_preds.emplace_back(classify(tree_, x));
return y_preds;
}
遍历测试集 X 的每个样本,使用 classify 函数分别对其进行预测,最终返回拼接好的预测结果。
string CartClassifier::classify(const shared_ptr<BinTreeNode>& tree, const vector<string>& x) {
const string& first_str = tree->feature_name_;
const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
const string& current_value = x[feature_index];
if (tree->left_ && current_value == tree->threshold_str_)
return classify(tree->left_, x);
if (tree->right_ && current_value != tree->threshold_str_)
return classify(tree->right_, x);
return tree->threshold_str_;
}
通过调用 classify 进行预测分类。参数 tree 的根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。在递归遍历过程中,从根节点开始,递归遍历 CART 分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。
剪枝
vector<shared_ptr<BinTreeNode>> CartClassifier::pruning(const vector<vector<string>>& X,
const vector<string>& y) {
return split_n_best_trees(X, y);
}
函数 pruning 根据不同的 α 区间生成不同剪枝程度的决策树集合。集合中越后面的决策树,剪枝程度越高。
vector<shared_ptr<BinTreeNode>> CartClassifier::split_n_best_trees(const vector<vector<string>>& X,
const vector<string>& y) {
vector<shared_ptr<BinTreeNode>> trees;
shared_ptr<BinTreeNode> tree = tree_->copy();
while (tree)
if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y)) {
trees.emplace_back(best_tree);
tree = best_tree->copy();
} else
tree = nullptr;
return trees;
}
split_n_best_trees 函数通过调用 split_1_best_trees 函数递归生成 n 棵预测误差最小的树,每一次递归的初始树均为上一次递归得到的最优剪枝树。为了在递归过程中不破坏上一轮得到的最优剪枝树,使用了深拷贝。
shared_ptr<BinTreeNode> CartClassifier::split_1_best_trees(const shared_ptr<BinTreeNode>& tree,
const vector<vector<string>>& X,
const vector<string>& y) {
vector<Info> infoSet;
const size_t NT = X.size();
calErrorRatio(tree, X, y, NT, infoSet);
if (infoSet.empty()) return nullptr;
double baseValue = 1;
int bestNode = 0;
for (int i = 0; i < infoSet.size(); i++)
if (infoSet[i].a < baseValue) {
baseValue = infoSet[i].a;
bestNode = i;
} else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
bestNode = i;
return prunBranch(tree, X, y, infoSet[bestNode]);
}
函数 split_1_best_tree 负责递归计算 α 值,并且选出 α 值最小的剪枝树。当前树的深度大于 1 时,开始进行 CCP 的迭代剪枝。在每次迭代内部,对每个分支节点进行 gi(t) 的计算,并选取最小值对应的子树进行剪枝。如果求得的最小 gi(t) 对应的子树有多个,则优先选取节点数目最多的子树作为修剪的对象。
Info CartClassifier::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X,
const vector<string>& y, const size_t NT, vector<Info>& infoSet) {
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_ && (tree->left_->left_ || tree->left_->right_)) {
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] == tree->threshold_str_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
info.key_str_ = {true, tree->threshold_str_};
infoSet.emplace_back(info);
}
if (tree->right_ && (tree->right_->left_ || tree->right_->right_)) {
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] != tree->threshold_str_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
info.key_str_ = {false, tree->threshold_str_};
infoSet.emplace_back(info);
}
const double Ct = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
const double CTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
const size_t Nt = getNumLeaf(tree);
const double a = Nt == 1 ? 2 : (Ct - CTt) / static_cast<double>(Nt - 1);
return {tree, Nt, a};
}
每次迭代中 gi(t) 的计算,也就是 calErrorRatio 函数。该函数主要计算节点 t 的误差率 C(t)、节点 t 对应子树 Tt 的误差率 C(Tt)、子树叶子节点的数目 |Tt|。gi(t) 的计算采用递归的方法,最终将所有 info 合并成节点信息集合。
size_t CartClassifier::nodeError(const vector<string>& y) {
string majorClass = majority_y(y);
return ranges::count_if(y, [&majorClass](const string& v) { return v != majorClass; });
}
size_t CartClassifier::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X,
const vector<string>& y) {
size_t error = 0;
for (int i = 0; i < X.size(); i++)
if (classify(tree, X[i]) != y[i]) ++error;
return error;
}
size_t CartClassifier::getNumLeaf(const shared_ptr<BinTreeNode>& tree) {
size_t numLeafs = 0;
if (tree->left_) numLeafs += getNumLeaf(tree->left_);
if (tree->right_) numLeafs += getNumLeaf(tree->right_);
if (!tree->left_ && !tree->right_) ++numLeafs;
return numLeafs;
}
shared_ptr<BinTreeNode> CartClassifier::prunBranch(const shared_ptr<BinTreeNode>& tree,
const vector<vector<string>>& X,
const vector<string>& y,
const Info& infoBran) {
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_) {
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] == tree->threshold_str_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
const string majorClass = majority_y(sub_y);
if (infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ &&
tree->left_ == infoBran.tree_) {
tree->left_ = make_shared<BinTreeNode>();
tree->left_->threshold_str_ = majorClass;
return tree;
}
tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
}
if (tree->right_) {
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] != tree->threshold_str_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
const string majorClass = majority_y(sub_y);
if (!infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ &&
tree->right_ == infoBran.tree_) {
tree->right_ = make_shared<BinTreeNode>();
tree->right_->threshold_str_ = majorClass;
return tree;
}
tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
}
return tree;
}
应用注意事项
- 为方便理解,代码仅考虑了离散字符串的分类,并未考虑其他离散值和连续值的分类,实际生产过程可能需要补充;
- 原则上来说,代码数据集中的字符串均需要通过编码(分类算法中编码无限制),以提升效率。为方便理解,本文章使用原始字符串,不影响结果;
- 代码中的
feature_name 仅作画图需要,实际生产如无该需求,可以去掉该变量;
- 对于 CCP 误差的计算,scikit-learn 使用基尼不纯度进行代替,因其不用每次使用预测计算,提高了效率。但基尼不纯度与误差之间仅具有相关性,无法通过基尼不纯度推导出误差,仅用作近似计算;
- 代码未考虑缺失值的处理;
- 代码没有适配多线程场景;
- 其他可能的算法时空复杂度的优化。
回归树
训练
CartRegressor 的创建和训练过程与 CartClassifier 类似。最重要的区别在于模型训练时切分点的选取。
shared_ptr<BinTreeNode> CartRegressor::train(const vector<vector<double>>& X,
const vector<double>& y,
const vector<string>& feature_names) {
feature_names_ = feature_names;
tree_ = create_tree(X, y);
return tree_;
}
train 的入参类型发生了变化,这是因为回归树使用的是连续类型数据。
shared_ptr<BinTreeNode> CartRegressor::create_tree(const vector<vector<double>>& X, const vector<double>& y) {
auto tree = make_shared<BinTreeNode>();
if (unordered_set(y.begin(), y.end()).size() == 1) {
tree->threshold_ = y.front();
return tree;
}
if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1) {
tree->threshold_ = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
return tree;
}
auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
const string_view best_feature_name = feature_names_[best_feature_index];
vector<vector<double>> sub_X1, sub_X2;
vector<double> sub_y1, sub_y2;
for (int i = 0; i < X.size(); i++)
if (X[i][best_feature_index] <= best_split_point) {
sub_X1.emplace_back(X[i]);
sub_y1.emplace_back(y[i]);
} else {
sub_X2.emplace_back(X[i]);
sub_y2.emplace_back(y[i]);
}
tree->feature_name_ = best_feature_name;
tree->threshold_ = best_split_point;
tree->left_ = create_tree(sub_X1, sub_y1);
tree->right_ = create_tree(sub_X2, sub_y2);
return tree;
}
在函数 create_tree 中,主要有 3 处与分类树不同:
- 当满足递归终止条件'节点样本数小于
min_samples_split_'时,返回的预测值是该集合中所有目标变量的平均值;
- 在
choose_best_point_to_split 函数中,在回归树中采用'平方误差最小'的原则来选择最优切分点;
- 使用最优属性和最优切分点划分数据集时相较分类树(处理匹配字符串'≤'和'>'的代码逻辑)做略微调整。
pair<double, int> CartRegressor::choose_best_point_to_split(const vector<vector<double>>& X,
const vector<double>& y) {
double best_split_point = 0, best_loss_all = numeric_limits<double>::infinity();
int best_feature_index = -1;
const size_t num_feature = X[0].size();
for (int i = 0; i < num_feature; ++i)
{
set<double> unique_feature_value;
vector<double> split_points;
for (const vector<double>& x : X) unique_feature_value.emplace(x[i]);
auto lit = unique_feature_value.begin(), rit = lit;
++rit;
while (rit != unique_feature_value.end()) {
split_points.emplace_back((*lit + *rit) / 2);
++lit;
++rit;
}
for (const double split_point : split_points) {
vector<double> sub_y_left, sub_y_right;
for (int j = 0; j < X.size(); j++)
if (X[j][i] <= split_point) sub_y_left.emplace_back(y[j]);
else sub_y_right.emplace_back(y[j]);
const double sub_y_left_mean = accumulate(sub_y_left.begin(), sub_y_left.end(), 0.) /
static_cast<double>(sub_y_left.size()),
sub_y_right_mean = accumulate(sub_y_right.begin(), sub_y_right.end(), 0.) /
static_cast<double>(sub_y_right.size());
double loss_left = 0, loss_right = 0;
for (const double j : sub_y_left) loss_left += pow(j - sub_y_left_mean, 2);
for (const double j : sub_y_right) loss_right += pow(j - sub_y_right_mean, 2);
if (const double loss_all = loss_left + loss_right; best_loss_all > loss_all) {
best_loss_all = loss_all;
best_feature_index = i;
best_split_point = split_point;
}
}
}
return {best_split_point, best_feature_index};
}
choose_best_point_to_split 遍历所有属性值时,回归树中不再计算基尼不纯度和基尼增益,而是针对回归问题计算损失函数。分别计算了使用当前切分点划分的左右子树的残差平方和,再计算左右子树的总残差平方和。最后选出取得最小损失函数的切分点和属性索引,作为最优切分点和最优分裂属性。
预测
vector<double> CartRegressor::predict(const vector<vector<double>>& X) {
vector<double> y_preds;
for (const vector<double>& x : X) y_preds.emplace_back(regression(tree_, x));
return y_preds;
}
double CartRegressor::regression(const shared_ptr<BinTreeNode>& tree, const vector<double>& x) {
const string& first_str = tree->feature_name_;
const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
const double current_value = x[feature_index];
if (tree->left_ && current_value <= tree->threshold_)
return regression(tree->left_, x);
if (tree->right_ && current_value > tree->threshold_)
return regression(tree->right_, x);
return tree->threshold_;
}
由于 CART 回归树与分类树的预测过程几乎完全相同,在此不做赘述。
剪枝
vector<shared_ptr<BinTreeNode>> CartRegressor::pruning(const vector<vector<double>>& X,
const vector<double>& y) {
return split_n_best_trees(X, y);
}
vector<shared_ptr<BinTreeNode>> CartRegressor::split_n_best_trees(const vector<vector<double>>& X,
const vector<double>& y) {
vector<shared_ptr<BinTreeNode>> trees;
shared_ptr<BinTreeNode> tree = tree_->copy();
while (tree)
if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y)) {
trees.emplace_back(best_tree);
tree = best_tree->copy();
} else
tree = nullptr;
return trees;
}
shared_ptr<BinTreeNode> CartRegressor::split_1_best_trees(const shared_ptr<BinTreeNode>& tree,
const vector<vector<double>>& X,
const vector<double>& y) {
vector<Info> infoSet;
const size_t NT = X.size();
calErrorRatio(tree, X, y, NT, infoSet);
if (infoSet.empty()) return nullptr;
double baseValue = 1;
int bestNode = 0;
for (int i = 0; i < infoSet.size(); i++)
if (infoSet[i].a < baseValue) {
baseValue = infoSet[i].a;
bestNode = i;
} else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
bestNode = i;
return prunBranch(tree, X, y, infoSet[bestNode]);
}
Info CartRegressor::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X,
const vector<double>& y, size_t NT, vector<Info>& infoSet) {
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_ && (tree->left_->left_ || tree->left_->right_)) {
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] <= tree->threshold_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
info.key_ = {true, tree->threshold_};
infoSet.emplace_back(info);
}
if (tree->right_ && (tree->right_->left_ || tree->right_->right_)) {
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] > tree->threshold_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
info.key_ = {false, tree->threshold_};
infoSet.emplace_back(info);
}
const double Rt = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
const double RTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
const size_t Nt = getNumLeaf(tree);
const double a = Nt == 1 ? 2 : (Rt - RTt) / static_cast<double>(Nt - 1);
return {tree, Nt, a};
}
size_t CartRegressor::nodeError(const vector<double>& y) {
const double mean_y = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
size_t error = 0;
for (const double& val : y) error += static_cast<size_t>(pow(val - mean_y, 2));
return error;
}
size_t CartRegressor::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X,
const vector<double>& y) {
size_t error = 0;
for (int i = 0; i < X.size(); i++) {
const double pred = regression(tree, X[i]);
error += static_cast<size_t>(pow(pred - y[i], 2));
}
return error;
}
size_t CartRegressor::getNumLeaf(const shared_ptr<BinTreeNode>& tree) {
size_t numLeafs = 0;
if (tree->left_) numLeafs += getNumLeaf(tree->left_);
if (tree->right_) numLeafs += getNumLeaf(tree->right_);
if (!tree->left_ && !tree->right_) ++numLeafs;
return numLeafs;
}
shared_ptr<BinTreeNode> CartRegressor::prunBranch(const shared_ptr<BinTreeNode>& tree,
const vector<vector<double>>& X,
const vector<double>& y,
const Info& infoBran) {
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_) {
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] <= tree->threshold_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
if (infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 &&
tree->left_ == infoBran.tree_) {
tree->left_ = make_shared<BinTreeNode>();
tree->left_->threshold_ = mean_val;
return tree;
}
tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
}
if (tree->right_) {
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] > tree->threshold_) {
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
if (!infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 &&
tree->right_ == infoBran.tree_) {
tree->right_ = make_shared<BinTreeNode>();
tree->right_->threshold_ = mean_val;
return tree;
}
tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
}
return tree;
}
回归树的剪枝与分类树类似,不同点在于回归树计算误差使用的是均方差。
应用注意事项
- 代码中的
feature_name 仅作画图需要,实际生产如无该需求,可以去掉该变量;
- 代码未考虑缺失值的处理;
- 分类树和回归树中的 CCP 算法,仅在误差计算中有区别。分类树中可以使用基尼系数或误分类率(从效率层面,推荐使用基尼系数),回归树中使用均方差;
- 代码没有适配多线程场景;
- 其他可能的算法时空复杂度的优化。