35 #ifndef VIGRA_RF_ALGORITHM_HXX
36 #define VIGRA_RF_ALGORITHM_HXX
38 #include "splices.hxx"
58 template<
class OrigMultiArray,
61 void choose(OrigMultiArray
const & in,
70 for(Iter iter = b; iter != e; ++iter, ++ii)
100 template<
class Feature_t,
class Response_t>
102 Response_t
const & response)
125 typedef std::vector<int> FeatureList_t;
126 typedef std::vector<double> ErrorList_t;
127 typedef FeatureList_t::iterator Pivot_t;
153 template<
class FeatureT,
156 class ErrorRateCallBack>
157 bool init(FeatureT
const & all_features,
158 ResponseT
const & response,
161 ErrorRateCallBack errorcallback)
163 bool ret_ = init(all_features, response, errorcallback);
166 vigra_precondition(std::distance(b, e) == static_cast<std::ptrdiff_t>(
selected.size()),
167 "Number of features in ranking != number of features matrix");
172 template<
class FeatureT,
175 bool init(FeatureT
const & all_features,
176 ResponseT
const & response,
181 return init(all_features, response, b, e, ecallback);
185 template<
class FeatureT,
187 bool init(FeatureT
const & all_features,
188 ResponseT
const & response)
190 return init(all_features, response, RFErrorCallback());
202 template<
class FeatureT,
204 class ErrorRateCallBack>
205 bool init(FeatureT
const & all_features,
206 ResponseT
const & response,
207 ErrorRateCallBack errorcallback)
215 selected.resize(all_features.shape(1), 0);
216 for(
unsigned int ii = 0; ii <
selected.size(); ++ii)
218 errors.resize(all_features.shape(1), -1);
219 errors.back() = errorcallback(all_features, response);
223 std::map<typename ResponseT::value_type, int> res_map;
224 std::vector<int> cts;
226 for(
int ii = 0; ii < response.shape(0); ++ii)
228 if(res_map.find(response(ii, 0)) == res_map.end())
230 res_map[response(ii, 0)] = counter;
234 cts[res_map[response(ii,0)]] +=1;
236 no_features = double(*(std::max_element(cts.begin(),
238 /
double(response.shape(0));
293 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
295 ResponseT
const & response,
297 ErrorRateCallBack errorcallback)
299 VariableSelectionResult::FeatureList_t & selected = result.
selected;
300 VariableSelectionResult::ErrorList_t & errors = result.
errors;
301 VariableSelectionResult::Pivot_t & pivot = result.pivot;
302 int featureCount = features.shape(1);
304 if(!result.init(features, response, errorcallback))
308 vigra_precondition(static_cast<int>(selected.size()) == featureCount,
309 "forward_selection(): Number of features in Feature "
310 "matrix and number of features in previously used "
311 "result struct mismatch!");
315 int not_selected_size = std::distance(pivot, selected.end());
316 while(not_selected_size > 1)
318 std::vector<double> current_errors;
319 VariableSelectionResult::Pivot_t next = pivot;
320 for(
int ii = 0; ii < not_selected_size; ++ii, ++next)
322 std::swap(*pivot, *next);
324 detail::choose( features,
328 double error = errorcallback(cur_feats, response);
329 current_errors.push_back(error);
330 std::swap(*pivot, *next);
332 int pos = std::distance(current_errors.begin(),
333 std::min_element(current_errors.begin(),
334 current_errors.end()));
336 std::advance(next, pos);
337 std::swap(*pivot, *next);
338 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
340 std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr,
", "));
341 std::cerr <<
"Choosing " << *pivot <<
" at error of " << current_errors[pos] << std::endl;
344 not_selected_size = std::distance(pivot, selected.end());
347 template<
class FeatureT,
class ResponseT>
349 ResponseT
const & response,
350 VariableSelectionResult & result)
395 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
397 ResponseT
const & response,
399 ErrorRateCallBack errorcallback)
401 int featureCount = features.shape(1);
402 VariableSelectionResult::FeatureList_t & selected = result.
selected;
403 VariableSelectionResult::ErrorList_t & errors = result.
errors;
404 VariableSelectionResult::Pivot_t & pivot = result.pivot;
407 if(!result.init(features, response, errorcallback))
411 vigra_precondition(static_cast<int>(selected.size()) == featureCount,
412 "backward_elimination(): Number of features in Feature "
413 "matrix and number of features in previously used "
414 "result struct mismatch!");
416 pivot = selected.end() - 1;
418 int selected_size = std::distance(selected.begin(), pivot);
419 while(selected_size > 1)
421 VariableSelectionResult::Pivot_t next = selected.begin();
422 std::vector<double> current_errors;
423 for(
int ii = 0; ii < selected_size; ++ii, ++next)
425 std::swap(*pivot, *next);
427 detail::choose( features,
431 double error = errorcallback(cur_feats, response);
432 current_errors.push_back(error);
433 std::swap(*pivot, *next);
435 int pos = std::distance(current_errors.begin(),
436 std::min_element(current_errors.begin(),
437 current_errors.end()));
438 next = selected.begin();
439 std::advance(next, pos);
440 std::swap(*pivot, *next);
442 errors[std::distance(selected.begin(), pivot)-1] = current_errors[pos];
443 selected_size = std::distance(selected.begin(), pivot);
445 std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr,
", "));
446 std::cerr <<
"Eliminating " << *pivot <<
" at error of " << current_errors[pos] << std::endl;
452 template<
class FeatureT,
class ResponseT>
454 ResponseT
const & response,
455 VariableSelectionResult & result)
492 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
494 ResponseT
const & response,
496 ErrorRateCallBack errorcallback)
498 VariableSelectionResult::FeatureList_t & selected = result.
selected;
499 VariableSelectionResult::ErrorList_t & errors = result.
errors;
500 VariableSelectionResult::Pivot_t & iter = result.pivot;
501 int featureCount = features.shape(1);
503 if(!result.init(features, response, errorcallback))
507 vigra_precondition(static_cast<int>(selected.size()) == featureCount,
508 "forward_selection(): Number of features in Feature "
509 "matrix and number of features in previously used "
510 "result struct mismatch!");
514 for(; iter != selected.end(); ++iter)
518 detail::choose( features,
522 double error = errorcallback(cur_feats, response);
523 errors[std::distance(selected.begin(), iter)] = error;
525 std::copy(selected.begin(), iter+1, std::ostream_iterator<int>(std::cerr,
", "));
526 std::cerr <<
"Choosing " << *(iter+1) <<
" at error of " << error << std::endl;
532 template<
class FeatureT,
class ResponseT>
534 ResponseT
const & response,
535 VariableSelectionResult & result)
542 enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
557 ClusterNode():NodeBase(){}
558 ClusterNode(
int nCol,
559 BT::T_Container_type & topology,
560 BT::P_Container_type & split_param)
561 : BT(nCol + 5, 5,topology, split_param)
571 ClusterNode( BT::T_Container_type
const & topology,
572 BT::P_Container_type
const & split_param,
574 :
NodeBase(5 , 5,topology, split_param, n)
580 ClusterNode( BT & node_)
585 BT::parameter_size_ += 0;
591 void set_index(
int in)
617 HC_Entry(
int p,
int l,
int a,
bool in)
618 : parent(p), level(l), addr(a), infm(in)
647 double dist_func(
double a,
double b)
649 return std::min(a, b);
655 template<
class Functor>
659 std::vector<int> stack;
660 stack.push_back(begin_addr);
661 while(!stack.empty())
663 ClusterNode node(topology_, parameters_, stack.
back());
667 if(node.columns_size() != 1)
669 stack.push_back(node.child(0));
670 stack.push_back(node.child(1));
678 template<
class Functor>
682 std::queue<HC_Entry> queue;
687 queue.push(
HC_Entry(parent,level,begin_addr, infm));
688 while(!queue.empty())
690 level = queue.front().level;
691 parent = queue.front().parent;
692 addr = queue.front().addr;
693 infm = queue.front().infm;
694 ClusterNode node(topology_, parameters_, queue.
front().addr);
698 parnt = ClusterNode(topology_, parameters_, parent);
701 bool istrue = tester(node, level, parnt, infm);
702 if(node.columns_size() != 1)
704 queue.push(
HC_Entry(addr, level +1,node.child(0),istrue));
705 queue.push(
HC_Entry(addr, level +1,node.child(1),istrue));
712 void save(std::string file, std::string prefix)
717 Shp(topology_.
size(),1),
721 Shp(parameters_.
size(), 1),
722 parameters_.
data()));
732 template<
class T,
class C>
736 std::vector<std::pair<int, int> > addr;
738 for(
int ii = 0; ii < distance.
shape(0); ++ii)
740 addr.push_back(std::make_pair(topology_.
size(), ii));
741 ClusterNode leaf(1, topology_, parameters_);
742 leaf.set_index(index);
744 leaf.columns_begin()[0] = ii;
747 while(addr.size() != 1)
752 double min_dist = dist((addr.begin()+ii_min)->second,
753 (addr.begin()+jj_min)->second);
754 for(
unsigned int ii = 0; ii < addr.size(); ++ii)
756 for(
unsigned int jj = ii+1; jj < addr.size(); ++jj)
758 if( dist((addr.begin()+ii_min)->second,
759 (addr.begin()+jj_min)->second)
760 > dist((addr.begin()+ii)->second,
761 (addr.begin()+jj)->second))
763 min_dist = dist((addr.begin()+ii)->second,
764 (addr.begin()+jj)->second);
776 ClusterNode firstChild(topology_,
778 (addr.begin() +ii_min)->first);
779 ClusterNode secondChild(topology_,
781 (addr.begin() +jj_min)->first);
782 col_size = firstChild.columns_size() + secondChild.columns_size();
784 int cur_addr = topology_.
size();
785 begin_addr = cur_addr;
787 ClusterNode parent(col_size,
790 ClusterNode firstChild(topology_,
792 (addr.begin() +ii_min)->first);
793 ClusterNode secondChild(topology_,
795 (addr.begin() +jj_min)->first);
796 parent.parameters_begin()[0] = min_dist;
797 parent.set_index(index);
799 std::merge(firstChild.columns_begin(), firstChild.columns_end(),
800 secondChild.columns_begin(),secondChild.columns_end(),
801 parent.columns_begin());
805 if(*parent.columns_begin() == *firstChild.columns_begin())
807 parent.child(0) = (addr.begin()+ii_min)->first;
808 parent.child(1) = (addr.begin()+jj_min)->first;
809 (addr.begin()+ii_min)->first = cur_addr;
811 to_desc = (addr.begin()+jj_min)->second;
812 addr.erase(addr.begin()+jj_min);
816 parent.child(1) = (addr.begin()+ii_min)->first;
817 parent.child(0) = (addr.begin()+jj_min)->first;
818 (addr.begin()+jj_min)->first = cur_addr;
820 to_desc = (addr.begin()+ii_min)->second;
821 addr.erase(addr.begin()+ii_min);
825 for(
int jj = 0 ; jj < static_cast<int>(addr.size()); ++jj)
829 double bla = dist_func(
830 dist(to_desc, (addr.begin()+jj)->second),
831 dist((addr.begin()+ii_keep)->second,
832 (addr.begin()+jj)->second));
834 dist((addr.begin()+ii_keep)->second,
835 (addr.begin()+jj)->second) = bla;
836 dist((addr.begin()+jj)->second,
837 (addr.begin()+ii_keep)->second) = bla;
858 bool operator()(Node& node)
871 template<
class Iter,
class DT>
876 Matrix<double> tmp_mem_;
879 Matrix<double> feats_;
886 template<
class Feat_T,
class Label_T>
889 Feat_T
const & feats,
890 Label_T
const & labls,
895 :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
898 feats_(_spl(a,b).size(), feats.shape(1)),
899 labels_(_spl(a,b).size(),1),
905 copy_splice(_spl(a,b),
906 _spl(feats.shape(1)),
909 copy_splice(_spl(a,b),
910 _spl(labls.shape(1)),
916 bool operator()(Node& node)
920 int class_count = perm_imp.
shape(1) - 1;
922 for(
int kk = 0; kk < nPerm; ++kk)
925 for(
int ii = 0; ii <
rowCount(feats_); ++ii)
928 for(
int jj = 0; jj < node.columns_size(); ++jj)
930 if(node.columns_begin()[jj] != feats_.shape(1))
931 tmp_mem_(ii, node.columns_begin()[jj])
932 = tmp_mem_(index, node.columns_begin()[jj]);
936 for(
int ii = 0; ii <
rowCount(tmp_mem_); ++ii)
943 ++perm_imp(index,labels_(ii, 0));
945 ++perm_imp(index, class_count);
949 double node_status = perm_imp(index, class_count);
950 node_status /= nPerm;
951 node_status -= orig_imp(0, class_count);
953 node_status /= oob_size;
954 node.status() += node_status;
975 void save(std::string file, std::string prefix)
983 bool operator()(Node& node)
985 for(
int ii = 0; ii < node.columns_size(); ++ii)
986 variables(index, ii) = node.columns_begin()[ii];
1000 bool operator()(Nde & cur,
int , Nde parent,
bool )
1003 cur.status() = std::min(parent.status(), cur.status());
1030 std::ofstream graphviz;
1035 std::string
const gz)
1036 :features_(features), labels_(labels),
1037 graphviz(gz.c_str(), std::ios::out)
1039 graphviz <<
"digraph G\n{\n node [shape=\"record\"]";
1043 graphviz <<
"\n}\n";
1048 bool operator()(Nde & cur,
int , Nde parent,
bool )
1050 graphviz <<
"node" << cur.index() <<
" [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() <<
"\\n";
1051 graphviz <<
" status: " << cur.status() <<
"\\n";
1052 for(
int kk = 0; kk < cur.columns_size(); ++kk)
1054 graphviz << cur.columns_begin()[kk] <<
" ";
1058 graphviz <<
"\"] [color = \"" <<cur.status() <<
" 1.000 1.000\"];\n";
1060 graphviz <<
"\"node" << parent.index() <<
"\" -> \"node" << cur.index() <<
"\";\n";
1080 int repetition_count_;
1086 void save(std::string filename, std::string prefix)
1088 std::string prefix1 =
"cluster_importance_" + prefix;
1092 prefix1 =
"vars_" + prefix;
1100 : repetition_count_(rep_cnt), clustering(clst)
1106 template<
class RF,
class PR>
1109 Int32 const class_count = rf.ext_param_.class_count_;
1110 Int32 const column_count = rf.ext_param_.column_count_+1;
1131 template<
class RF,
class PR,
class SM,
class ST>
1135 Int32 column_count = rf.ext_param_.column_count_ +1;
1136 Int32 class_count = rf.ext_param_.class_count_;
1140 typename PR::Feature_t & features
1141 =
const_cast<typename PR::Feature_t &
>(pr.features());
1148 if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
1152 for(
int ii = 0; ii < pr.features().shape(0); ++ii)
1153 indices.push_back(ii);
1154 std::random_shuffle(indices.begin(), indices.end());
1155 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1157 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
1159 oob_indices.push_back(indices[ii]);
1160 ++cts[pr.response()(indices[ii], 0)];
1166 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1167 if(!sm.is_used()[ii])
1168 oob_indices.push_back(ii);
1178 oob_right(Shp_t(1, class_count + 1));
1181 for(iter = oob_indices.
begin();
1182 iter != oob_indices.
end();
1186 .predictLabel(
rowVector(features, *iter))
1187 == pr.response()(*iter, 0))
1190 ++oob_right[pr.response()(*iter,0)];
1192 ++oob_right[class_count];
1197 perm_oob_right (Shp_t(2* column_count-1, class_count + 1));
1200 pc(oob_indices.
begin(), oob_indices.
end(),
1209 perm_oob_right /= repetition_count_;
1210 for(
int ii = 0; ii <
rowCount(perm_oob_right); ++ii)
1211 rowVector(perm_oob_right, ii) -= oob_right;
1213 perm_oob_right *= -1;
1214 perm_oob_right /= oob_indices.
size();
1223 template<
class RF,
class PR,
class SM,
class ST>
1231 template<
class RF,
class PR>
1271 template<
class FeatureT,
class ResponseT>
1273 ResponseT
const & response,
1280 if(features.shape(0) > 40000)
1287 RF.
learn(features, response,
1288 create_visitor(missc, progress));
1303 create_visitor(progress, ci));
1316 template<
class FeatureT,
class ResponseT>
1318 ResponseT
const & response,
1319 HClustering & linkage)
1326 template<
class Array1,
class Vector1>
1327 void get_ranking(Array1
const & in, Vector1 & out)
1329 std::map<double, int> mymap;
1330 for(
int ii = 0; ii < in.size(); ++ii)
1332 for(std::map<double, int>::reverse_iterator iter = mymap.rbegin(); iter!= mymap.rend(); ++iter)
1334 out.push_back(iter->second);
1340 #endif //VIGRA_RF_ALGORITHM_HXX
UInt32 uniformInt() const
Definition: random.hxx:464
double no_features
Definition: rf_algorithm.hxx:151
reference back()
Definition: array_vector.hxx:321
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArray< 2, double > cluster_stdev_
Definition: rf_algorithm.hxx:1079
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
Topology_type column_data() const
Definition: rf_nodeproxy.hxx:159
MultiArray< 2, double > cluster_importance_
Definition: rf_algorithm.hxx:1076
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition: rf_common.hxx:411
const difference_type & shape() const
Definition: multi_array.hxx:1648
Definition: rf_algorithm.hxx:1067
void visit_at_end(RF &rf, PR &)
Definition: rf_algorithm.hxx:1232
const_iterator begin() const
Definition: array_vector.hxx:223
void visit_at_beginning(RF const &rf, PR const &)
Definition: rf_algorithm.hxx:1107
NodeBase()
Definition: rf_nodeproxy.hxx:237
Definition: rf_algorithm.hxx:847
NormalizeStatus(double m)
Definition: rf_algorithm.hxx:854
Definition: rf_algorithm.hxx:996
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Definition: rf_visitors.hxx:863
void forward_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:294
Definition: rf_visitors.hxx:1495
void backward_elimination(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:396
Definition: rf_algorithm.hxx:611
Definition: rf_algorithm.hxx:872
Definition: rf_algorithm.hxx:83
difference_type_1 size() const
Definition: multi_array.hxx:1641
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1524
Definition: multi_fwd.hxx:63
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition: random_forest.hxx:147
reference front()
Definition: array_vector.hxx:307
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
bool init(FeatureT const &all_features, ResponseT const &response, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:205
void breadth_first_traversal(Functor &tester)
Definition: rf_algorithm.hxx:679
Definition: rf_algorithm.hxx:638
void cluster_permutation_importance(FeatureT const &features, ResponseT const &response, HClustering &linkage, MultiArray< 2, double > &distance)
Definition: rf_algorithm.hxx:1272
Definition: rf_algorithm.hxx:963
Parameter_type parameters_begin() const
Definition: rf_nodeproxy.hxx:207
Definition: metaprogramming.hxx:123
double oob_breiman
Definition: rf_visitors.hxx:874
ErrorList_t errors
Definition: rf_algorithm.hxx:146
void writeHDF5(...)
Store array data in an HDF5 file.
Definition: rf_visitors.hxx:1460
INT & typeID()
Definition: rf_nodeproxy.hxx:136
void cluster(MultiArrayView< 2, T, C > distance)
Definition: rf_algorithm.hxx:733
Definition: rf_algorithm.hxx:1024
MultiArray< 2, int > variables
Definition: rf_algorithm.hxx:1073
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: rf_visitors.hxx:101
void rank_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:493
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_algorithm.hxx:1224
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
Definition: random.hxx:336
Options object for the random forest.
Definition: rf_common.hxx:170
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_algorithm.hxx:1132
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition: rf_common.hxx:374
RandomForestOptions & tree_count(unsigned int in)
Definition: rf_common.hxx:500
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:941
const_iterator end() const
Definition: array_vector.hxx:237
const_pointer data() const
Definition: array_vector.hxx:209
void iterate(Functor &tester)
Definition: rf_algorithm.hxx:656
FeatureList_t selected
Definition: rf_algorithm.hxx:133
size_type size() const
Definition: array_vector.hxx:358
MultiArrayView< 2, int > variables
Definition: rf_algorithm.hxx:969
double operator()(Feature_t const &features, Response_t const &response)
Definition: rf_algorithm.hxx:101
RFErrorCallback(RandomForestOptions opt=RandomForestOptions())
Definition: rf_algorithm.hxx:93
Definition: rf_algorithm.hxx:116
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344