35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
39 # include "vigra/hdf5impex.hxx"
41 #include <vigra/windows.h>
45 #include <vigra/metaprogramming.hxx>
46 #include <vigra/multi_pointoperators.hxx>
141 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
147 Feature_t & features,
150 ignore_argument(tree,split,parent,leftChild,rightChild,features,labels);
162 template<
class RF,
class PR,
class SM,
class ST>
165 ignore_argument(rf,pr,sm,st,index);
174 template<
class RF,
class PR>
177 ignore_argument(rf,pr);
186 template<
class RF,
class PR>
189 ignore_argument(rf,pr);
204 template<
class TR,
class IntT,
class TopT,
class Feat>
207 ignore_argument(tr,index,node_t,features);
214 template<
class TR,
class IntT,
class TopT,
class Feat>
253 template <
class Visitor,
class Next = StopVisiting>
263 next_(next), visitor_(visitor)
268 next_(stop_), visitor_(visitor)
271 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
272 void visit_after_split( Tree & tree,
277 Feature_t & features,
280 if(visitor_.is_active())
281 visitor_.visit_after_split(tree, split,
282 parent, leftChild, rightChild,
284 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
288 template<
class RF,
class PR,
class SM,
class ST>
289 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
291 if(visitor_.is_active())
292 visitor_.visit_after_tree(rf, pr, sm, st, index);
293 next_.visit_after_tree(rf, pr, sm, st, index);
296 template<
class RF,
class PR>
297 void visit_at_beginning(RF & rf, PR & pr)
299 if(visitor_.is_active())
300 visitor_.visit_at_beginning(rf, pr);
301 next_.visit_at_beginning(rf, pr);
303 template<
class RF,
class PR>
304 void visit_at_end(RF & rf, PR & pr)
306 if(visitor_.is_active())
307 visitor_.visit_at_end(rf, pr);
308 next_.visit_at_end(rf, pr);
311 template<
class TR,
class IntT,
class TopT,
class Feat>
312 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
314 if(visitor_.is_active())
315 visitor_.visit_external_node(tr, index, node_t,features);
316 next_.visit_external_node(tr, index, node_t,features);
318 template<
class TR,
class IntT,
class TopT,
class Feat>
319 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
321 if(visitor_.is_active())
322 visitor_.visit_internal_node(tr, index, node_t,features);
323 next_.visit_internal_node(tr, index, node_t,features);
328 if(visitor_.is_active() && visitor_.has_value())
329 return visitor_.return_val();
330 return next_.return_val();
354 template<
class A,
class B>
355 detail::VisitorNode<A, detail::VisitorNode<B> >
368 template<
class A,
class B,
class C>
369 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
384 template<
class A,
class B,
class C,
class D>
385 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
386 detail::VisitorNode<D> > > >
403 template<
class A,
class B,
class C,
class D,
class E>
404 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
405 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
425 template<
class A,
class B,
class C,
class D,
class E,
427 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
428 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
450 template<
class A,
class B,
class C,
class D,
class E,
452 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
453 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
454 detail::VisitorNode<G> > > > > > >
456 D & d, E & e, F & f, G & g)
478 template<
class A,
class B,
class C,
class D,
class E,
479 class F,
class G,
class H>
480 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
481 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
482 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
509 template<
class A,
class B,
class C,
class D,
class E,
510 class F,
class G,
class H,
class I>
511 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
512 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
513 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
541 template<
class A,
class B,
class C,
class D,
class E,
542 class F,
class G,
class H,
class I,
class J>
543 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
544 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
545 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
546 detail::VisitorNode<J> > > > > > > > > >
587 bool adjust_thresholds;
597 adjust_thresholds(
false), tree_id(0), last_node_id(0), current_label(0)
599 struct MarginalDistribution
602 Int32 leftTotalCounts;
604 Int32 rightTotalCounts;
611 struct TreeOnlineInformation
613 std::vector<MarginalDistribution> mag_distributions;
614 std::vector<IndexList> index_lists;
616 std::map<int,int> interior_to_index;
618 std::map<int,int> exterior_to_index;
622 std::vector<TreeOnlineInformation> trees_online_information;
626 template<
class RF,
class PR>
630 trees_online_information.resize(rf.options_.tree_count_);
637 trees_online_information[tree_id].mag_distributions.clear();
638 trees_online_information[tree_id].index_lists.clear();
639 trees_online_information[tree_id].interior_to_index.clear();
640 trees_online_information[tree_id].exterior_to_index.clear();
645 template<
class RF,
class PR,
class SM,
class ST>
651 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
652 void visit_after_split( Tree & tree,
657 Feature_t & features,
661 int addr=tree.topology_.size();
662 if(split.createNode().typeID() == i_ThresholdNode)
664 if(adjust_thresholds)
667 linear_index=trees_online_information[tree_id].mag_distributions.size();
668 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
669 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
671 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
672 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
674 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
675 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
677 double gap_left,gap_right;
679 gap_left=features(leftChild[0],split.bestSplitColumn());
680 for(i=1;i<leftChild.size();++i)
681 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
682 gap_left=features(leftChild[i],split.bestSplitColumn());
683 gap_right=features(rightChild[0],split.bestSplitColumn());
684 for(i=1;i<rightChild.size();++i)
685 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
686 gap_right=features(rightChild[i],split.bestSplitColumn());
687 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
688 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
694 linear_index=trees_online_information[tree_id].index_lists.size();
695 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
697 trees_online_information[tree_id].index_lists.push_back(IndexList());
699 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
700 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
703 void add_to_index_list(
int tree,
int node,
int index)
707 TreeOnlineInformation &ti=trees_online_information[tree];
708 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
710 void move_exterior_node(
int src_tree,
int src_index,
int dst_tree,
int dst_index)
714 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
715 trees_online_information[src_tree].exterior_to_index.erase(src_index);
722 template<
class TR,
class IntT,
class TopT,
class Feat>
726 if(adjust_thresholds)
728 vigra_assert(node_t==i_ThresholdNode,
"We can only visit threshold nodes");
730 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
731 TreeOnlineInformation &ti=trees_online_information[tree_id];
732 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
733 if(value>m.gap_left && value<m.gap_right)
736 if(m.leftCounts[current_label]/
double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
746 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
749 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
751 ++m.rightTotalCounts;
752 ++m.rightCounts[current_label];
757 ++m.rightCounts[current_label];
805 template<
class RF,
class PR,
class SM,
class ST>
809 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
811 oobCount.resize(rf.ext_param_.row_count_, 0);
812 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
815 for(
int l = 0; l < rf.ext_param_.row_count_; ++l)
822 .predictLabel(
rowVector(pr.features(), l))
823 != pr.response()(l,0))
834 template<
class RF,
class PR>
838 for(
int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
842 oobError += double(oobErrorCount[l]) / oobCount[l];
880 void save(std::string filen, std::string pathn)
882 if(*(pathn.end()-1) !=
'/')
884 const char* filename = filen.c_str();
887 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
893 template<
class RF,
class PR>
894 void visit_at_beginning(RF & rf, PR &)
896 class_count = rf.class_count();
897 tmp_prob.
reshape(Shp(1, class_count), 0);
898 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
899 is_weighted = rf.options().predict_weighted_;
900 indices.resize(rf.ext_param().row_count_);
901 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
903 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
905 for(
int ii = 0; ii < rf.ext_param().row_count_; ++ii)
911 template<
class RF,
class PR,
class SM,
class ST>
912 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &,
int index)
919 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
921 ArrayVector<int> oob_indices;
922 ArrayVector<int> cts(class_count, 0);
923 std::random_shuffle(indices.
begin(), indices.
end());
924 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
926 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
928 oob_indices.push_back(indices[ii]);
929 ++cts[pr.response()(indices[ii], 0)];
932 for(
unsigned int ll = 0; ll < oob_indices.size(); ++ll)
935 ++oobCount[oob_indices[ll]];
940 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),oob_indices[ll]));
941 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
942 rf.tree(index).parameters_,
945 for(
int ii = 0; ii < class_count; ++ii)
947 tmp_prob[ii] = node.prob_begin()[ii];
951 for(
int ii = 0; ii < class_count; ++ii)
952 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
954 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
959 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
962 if(!sm.is_used()[ll])
970 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
971 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
972 rf.tree(index).parameters_,
975 for(
int ii = 0; ii < class_count; ++ii)
977 tmp_prob[ii] = node.prob_begin()[ii];
981 for(
int ii = 0; ii < class_count; ++ii)
982 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
993 template<
class RF,
class PR>
997 int totalOobCount =0;
998 int breimanstyle = 0;
999 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1074 void save(std::string filen, std::string pathn)
1076 if(*(pathn.end()-1) !=
'/')
1078 const char* filename = filen.c_str();
1084 writeHDF5(filename, (pathn +
"per_tree_error").c_str(), temp);
1086 writeHDF5(filename, (pathn +
"per_tree_error_std").c_str(), temp);
1088 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
1090 writeHDF5(filename, (pathn +
"ulli_error").c_str(), temp);
1096 template<
class RF,
class PR>
1097 void visit_at_beginning(RF & rf, PR &)
1099 class_count = rf.class_count();
1100 if(class_count == 2)
1104 tmp_prob.
reshape(Shp(1, class_count), 0);
1105 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1106 is_weighted = rf.options().predict_weighted_;
1110 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
1112 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1113 oobErrorCount.
reshape(Shp(rf.ext_param_.row_count_,1), 0);
1117 template<
class RF,
class PR,
class SM,
class ST>
1118 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &,
int index)
1123 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1126 if(!sm.is_used()[ll])
1134 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
1135 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1136 rf.tree(index).parameters_,
1139 for(
int ii = 0; ii < class_count; ++ii)
1141 tmp_prob[ii] = node.prob_begin()[ii];
1145 for(
int ii = 0; ii < class_count; ++ii)
1146 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1149 int label =
argMax(tmp_prob);
1151 if(label != pr.response()(ll, 0))
1156 ++oobErrorCount[ll];
1160 int breimanstyle = 0;
1161 int totalOobCount = 0;
1162 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1179 MultiArrayView<3, double> current_roc
1181 for(
int gg = 0; gg < current_roc.shape(2); ++gg)
1183 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1187 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1189 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1192 current_roc.
bindOuter(gg)/= totalOobCount;
1196 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1202 template<
class RF,
class PR>
1207 int totalOobCount =0;
1208 int breimanstyle = 0;
1209 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1260 int repetition_count_;
1264 void save(std::string filename, std::string prefix)
1266 prefix =
"variable_importance_" + prefix;
1279 : repetition_count_(rep_cnt)
1286 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1297 Int32 const class_count = tree.ext_param_.class_count_;
1298 Int32 const column_count = tree.ext_param_.column_count_;
1307 if(split.createNode().typeID() == i_ThresholdNode)
1309 Node<i_ThresholdNode> node(split.createNode());
1311 += split.region_gini_ - split.minGini();
1321 template<
class RF,
class PR,
class SM,
class ST>
1325 Int32 column_count = rf.ext_param_.column_count_;
1326 Int32 class_count = rf.ext_param_.class_count_;
1336 typedef typename PR::FeatureWithMemory_t FeatureArray;
1337 typedef typename FeatureArray::value_type FeatureValue;
1339 FeatureArray features = pr.features();
1345 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1346 if(!sm.is_used()[ii])
1347 oob_indices.push_back(ii);
1353 #ifdef CLASSIFIER_TEST
1364 oob_right(Shp_t(1, class_count + 1));
1366 perm_oob_right (Shp_t(1, class_count + 1));
1370 for(iter = oob_indices.
begin();
1371 iter != oob_indices.
end();
1375 .predictLabel(
rowVector(features, *iter))
1376 == pr.response()(*iter, 0))
1379 ++oob_right[pr.response()(*iter,0)];
1381 ++oob_right[class_count];
1385 for(
int ii = 0; ii < column_count; ++ii)
1387 perm_oob_right.
init(0.0);
1389 backup_column.clear();
1390 for(iter = oob_indices.
begin();
1391 iter != oob_indices.
end();
1394 backup_column.push_back(features(*iter,ii));
1398 for(
int rr = 0; rr < repetition_count_; ++rr)
1401 int n = oob_indices.
size();
1402 for(
int jj = n-1; jj >= 1; --jj)
1403 std::swap(features(oob_indices[jj], ii),
1404 features(oob_indices[randint(jj+1)], ii));
1407 for(iter = oob_indices.
begin();
1408 iter != oob_indices.
end();
1412 .predictLabel(
rowVector(features, *iter))
1413 == pr.response()(*iter, 0))
1416 ++perm_oob_right[pr.response()(*iter, 0)];
1418 ++perm_oob_right[class_count];
1425 perm_oob_right /= repetition_count_;
1426 perm_oob_right -=oob_right;
1427 perm_oob_right *= -1;
1428 perm_oob_right /= oob_indices.
size();
1431 Shp_t(ii+1,class_count+1)) += perm_oob_right;
1433 for(
int jj = 0; jj < int(oob_indices.
size()); ++jj)
1434 features(oob_indices[jj], ii) = backup_column[jj];
1443 template<
class RF,
class PR,
class SM,
class ST>
1451 template<
class RF,
class PR>
1464 template<
class RF,
class PR,
class SM,
class ST>
1465 void visit_after_tree(RF& rf, PR &, SM &, ST &,
int index){
1466 if(index != rf.options().tree_count_-1) {
1467 std::cout <<
"\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 <<
"%]"
1468 <<
" (" << index+1 <<
" of " << rf.options().tree_count_ <<
") done" << std::flush;
1471 std::cout <<
"\r[" << std::setw(10) << 100.0 <<
"%]" << std::endl;
1475 template<
class RF,
class PR>
1476 void visit_at_end(RF
const & rf, PR
const &) {
1477 std::string a =
TOCS;
1478 std::cout <<
"all " << rf.options().tree_count_ <<
" trees have been learned in " << a << std::endl;
1481 template<
class RF,
class PR>
1482 void visit_at_beginning(RF
const & rf, PR
const &) {
1484 std::cout <<
"growing random forest, which will have " << rf.options().tree_count_ <<
" trees" << std::endl;
1532 void save(std::string, std::string)
1550 template<
class RF,
class PR>
1551 void visit_at_beginning(RF
const & rf, PR & pr)
1554 int n = rf.ext_param_.column_count_;
1557 corr_l.
reshape(Shp(n +1, 10));
1560 noise_l.
reshape(Shp(pr.features().shape(0), 10));
1562 for(
int ii = 0; ii <
noise.
size(); ++ii)
1564 noise[ii] = random.uniform53();
1565 noise_l[ii] = random.uniform53() > 0.5;
1567 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1568 tmp_labels.
reshape(pr.response().shape());
1573 template<
class RF,
class PR>
1574 void visit_at_end(RF
const &, PR
const &)
1583 for(
int jj = 0; jj < rC-1; ++jj)
1588 for(
int jj = 0; jj < rC; ++jj)
1594 FindMinMax<double> minmax;
1597 for(
int jj = 0; jj < rC; ++jj)
1604 for(
int jj = 0; jj < rC; ++jj)
1607 FindMinMax<double> minmax2;
1609 for(
int jj = 0; jj < rC; ++jj)
1615 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1616 void visit_after_split( Tree &,
1621 Feature_t & features,
1624 if(split.createNode().typeID() == i_ThresholdNode)
1628 for(
int ii = 0; ii < parent.size(); ++ii)
1630 tmp_labels[parent[ii]]
1631 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1632 ++tmp_cc[tmp_labels[parent[ii]]];
1634 double region_gini = bgfunc.loss_of_region(tmp_labels,
1639 int n = split.bestSplitColumn();
1643 for(
int k = 0; k < features.shape(1); ++k)
1647 parent.begin(), parent.end(),
1649 wgini = (region_gini - bgfunc.min_gini_);
1653 for(
int k = 0; k < 10; ++k)
1657 parent.begin(), parent.end(),
1659 wgini = (region_gini - bgfunc.min_gini_);
1664 for(
int k = 0; k < 10; ++k)
1668 parent.begin(), parent.end(),
1670 wgini = (region_gini - bgfunc.min_gini_);
1674 bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1675 wgini = (region_gini - bgfunc.min_gini_);
1679 region_gini = split.region_gini_;
1681 Node<i_ThresholdNode> node(split.createNode());
1684 +=split.region_gini_ - split.minGini();
1686 for(
int k = 0; k < 10; ++k)
1690 parent.begin(), parent.end(),
1691 parent.classCounts());
1697 for(
int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1699 wgini = region_gini - split.min_gini_[k];
1702 split.splitColumns[k])
1706 for(
int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1708 split.bgfunc(
columnVector(features, split.splitColumns[k]),
1710 parent.begin(), parent.end(),
1711 parent.classCounts());
1712 wgini = region_gini - split.bgfunc.min_gini_;
1714 split.splitColumns[k]) += wgini;
1721 SortSamplesByDimensions<Feature_t>
1722 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1723 std::partition(parent.begin(), parent.end(), sorter);
1733 #endif // RF_VISITORS_HXX
#define TIC
Definition: timing.hxx:322
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:994
MultiArray< 2, double > oob_per_tree
Definition: rf_visitors.hxx:1025
void visit_at_beginning(RF &rf, const PR &)
Definition: rf_visitors.hxx:627
MultiArray< 2, double > noise
Definition: rf_visitors.hxx:1505
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
MultiArray< 2, double > variable_importance_
Definition: rf_visitors.hxx:1259
const difference_type & shape() const
Definition: multi_array.hxx:1648
MultiArray< 2, double > breiman_per_tree
Definition: rf_visitors.hxx:1050
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_visitors.hxx:1322
MultiArray< 2, double > corr_noise
Definition: rf_visitors.hxx:1509
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition: rf_visitors.hxx:646
const_iterator begin() const
Definition: array_vector.hxx:223
MultiArray< 2, double > similarity
Definition: rf_visitors.hxx:1521
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Definition: rf_visitors.hxx:863
MultiArray< 4, double > oobroc_per_tree
Definition: rf_visitors.hxx:1067
Definition: rf_visitors.hxx:1495
ArrayVector< int > numChoices
Definition: rf_visitors.hxx:1529
Definition: rf_visitors.hxx:1230
MultiArrayView< N, T, StridedArrayTag > transpose() const
Definition: multi_array.hxx:1567
double oob_breiman
Definition: rf_visitors.hxx:1038
double oob_mean
Definition: rf_visitors.hxx:1028
MultiArray< 2, double > gini_missc
Definition: rf_visitors.hxx:1501
double return_val()
Definition: rf_visitors.hxx:225
difference_type_1 size() const
Definition: multi_array.hxx:1641
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1524
Definition: multi_fwd.hxx:63
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:635
double oobError
Definition: rf_visitors.hxx:787
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition: rf_visitors.hxx:215
void init(U const &initial)
Definition: array_vector.hxx:146
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_visitors.hxx:806
double oob_per_tree2
Definition: rf_visitors.hxx:1045
Definition: rf_split.hxx:831
MultiArray & init(const U &init)
Definition: multi_array.hxx:2851
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
void visit_at_end(RF &rf, PR &)
Definition: rf_visitors.hxx:1452
Definition: rf_visitors.hxx:1015
Definition: rf_visitors.hxx:583
double oob_breiman
Definition: rf_visitors.hxx:874
#define TOCS
Definition: timing.hxx:325
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void writeHDF5(...)
Store array data in an HDF5 file.
Definition: rf_visitors.hxx:1460
double oob_std
Definition: rf_visitors.hxx:1031
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:723
void visit_at_end(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:175
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:1203
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: rf_visitors.hxx:101
void visit_at_end(RF &rf, PR &)
Definition: rf_visitors.hxx:835
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
void visit_at_beginning(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:187
Definition: random.hxx:336
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition: rf_visitors.hxx:142
const_iterator end() const
Definition: array_vector.hxx:237
const_pointer data() const
Definition: array_vector.hxx:209
size_type size() const
Definition: array_vector.hxx:358
MultiArrayView subarray(difference_type p, difference_type q) const
Definition: multi_array.hxx:1528
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition: rf_visitors.hxx:1287
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:205
Definition: rf_visitors.hxx:782
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:163
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:1444
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const
Definition: multi_array.hxx:2184
Definition: rf_visitors.hxx:234
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344