37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
50 #include "metaprogramming.hxx"
52 #include "functorexpression.hxx"
53 #include "random_forest/rf_common.hxx"
54 #include "random_forest/rf_nodeproxy.hxx"
55 #include "random_forest/rf_split.hxx"
56 #include "random_forest/rf_decisionTree.hxx"
57 #include "random_forest/rf_visitors.hxx"
58 #include "random_forest/rf_region.hxx"
59 #include "sampling.hxx"
60 #include "random_forest/rf_preprocessing.hxx"
61 #include "random_forest/rf_online_prediction_set.hxx"
62 #include "random_forest/rf_earlystopping.hxx"
63 #include "random_forest/rf_ridge_split.hxx"
// Builds the SamplerOptions that learn()/reLearnTree() pass to the sampler
// which draws the in-bag / out-of-bag indices for each tree.
// Stratified (per-stratum) sampling is enabled exactly when the forest's
// stratification_method_ equals RF_EQUAL.
// NOTE(review): further option transfers may live in lines elided from this
// listing; only the stratification setting is visible here.
83 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
85 SamplerOptions return_opt;
87 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
// RandomForest class template (random forest, version 2).
// LabelType defaults to double, PreprocessorTag to ClassificationTag.
146 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
// Type of a single tree of the ensemble.
153 typedef detail::DecisionTree DecisionTree_t;
// Alias of the label type for users of the class.
160 typedef LabelType LabelT;
// Constructor that rebuilds a forest from externally stored trees.
// All treeCount trees are first constructed from problem_spec; afterwards
// tree k receives the k-th topology and the k-th parameter array from the
// two iterator ranges.
// NOTE(review): no bounds check is visible; both ranges are assumed to
// provide at least treeCount elements.
227 template<
class TopologyIterator,
class ParameterIterator>
229 TopologyIterator topology_begin,
230 ParameterIterator parameter_begin,
// treeCount trees built from problem_spec; external parameters copied from it.
234 trees_(treeCount, DecisionTree_t(problem_spec)),
235 ext_param_(problem_spec),
// Both iterators advance in lockstep with k.
241 for(
int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
243 trees_[k].topology_ = *topology_begin;
244 trees_[k].parameters_ = *parameter_begin;
// ext_param(): read-only access to the external problem specification.
// Precondition: ext_param_.used() must be true; otherwise the precondition
// fails with "Random forest has not been trained yet."
262 vigra_precondition(ext_param_.used() ==
true,
263 "RandomForest::ext_param(): "
264 "Random forest has not been trained yet.");
// set_ext_param(): replaces the external problem specification.
// Only allowed while the current specification is unused
// (ext_param_.used() == false); after training, reset() must be called first.
// The adjacent string literals below are concatenated by the compiler, so
// each piece carries its own separating space.
281 vigra_precondition(ext_param_.used() ==
false,
282 "RandomForest::set_ext_param(): "
283 "Random forest has been trained! Call reset() "
284 "before specifying new extrinsic parameters.");
// tree(index) const: read-only access to the tree at position `index`.
// No range check is visible here; `index` is expected to lie in
// [0, tree_count()).
308 DecisionTree_t
const &
tree(
int index)
const
310 return trees_[index];
// tree(index): mutable access to the tree at position `index`.
// No range check is visible here; `index` is expected to lie in
// [0, tree_count()).
315 DecisionTree_t &
tree(
int index)
317 return trees_[index];
// feature_count(): number of feature columns seen during training.
325 return ext_param_.column_count_;
// column_count(): same value as feature_count().
336 return ext_param_.column_count_;
// class_count(): number of classes seen during training.
344 return ext_param_.class_count_;
// tree_count(): number of trees configured in the options.
351 return options_.tree_count_;
// learn() overloads.
// Most general form: features, response, visitor, split, stop and a random
// number generator passed as `Random_t const & random`.
392 template <
class U,
class C1,
403 Random_t
const & random);
// Further learn() overload (signature elided in this listing).
405 template <
class U,
class C1,
// Overload whose extra template parameter is only a visitor type.
426 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
427 void learn( MultiArrayView<2, U, C1>
const & features,
428 MultiArrayView<2, U2,C2>
const & labels,
// Overload taking a visitor and a split functor.
438 template <
class U,
class C1,
class U2,
class C2,
439 class Visitor_t,
class Split_t>
440 void learn( MultiArrayView<2, U, C1>
const & features,
441 MultiArrayView<2, U2,C2>
const & labels,
// Default configuration: features and labels only.
470 template <
class U,
class C1,
class U2,
class C2>
// onlineLearn() declaration: updates an already trained forest with new
// samples (definition further below). adjust_thresholds defaults to false.
482 template<
class U,
class C1,
495 bool adjust_thresholds=
false);
// Convenience overload of onlineLearn() that forwards to the full version.
// NOTE(review): the remaining forwarded arguments are elided in this listing.
497 template <
class U,
class C1,
class U2,
class C2>
502 onlineLearn(features,
// reLearnTree() declaration: discards a single tree and grows it again
// (see the definition further below).
512 template<
class U,
class C1,
518 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
519 MultiArrayView<2,U2,C2>
const & response,
// Convenience overload of reLearnTree() that supplies a random number
// generator seeded with RandomSeed.
// NOTE(review): the forwarding call itself is elided from this listing.
526 template<
class U,
class C1,
class U2,
class C2>
527 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
528 MultiArrayView<2, U2, C2>
const & labels,
531 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
// predictLabel() overloads; the definitions require `features` to have
// exactly one row.
// With a Stop object that is forwarded to predictProbabilities():
561 template <
class U,
class C,
class Stop>
562 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features, Stop & stop)
const;
// Plain version:
564 template <
class U,
class C>
565 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features)
// With per-class prior weights applied to the predicted probabilities:
575 template <
class U,
class C>
576 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
577 ArrayVectorView<double> prior)
const;
// predictLabels(features, labels): predicts one label per row of `features`.
// Preconditions: `labels` has as many rows as `features`, and no feature row
// contains NaN (the nanLabel overload tolerates NaN rows instead).
589 template <
class U,
class C1,
class T,
class C2>
593 vigra_precondition(features.
shape(0) == labels.
shape(0),
594 "RandomForest::predictLabels(): Label array has wrong size.");
595 for(
int k=0; k<features.
shape(0); ++k)
597 vigra_precondition(!detail::contains_nan(
rowVector(features, k)),
598 "RandomForest::predictLabels(): NaN in feature matrix.");
// predictLabels(features, labels, nanLabel): like predictLabels(), but a row
// that contains NaN is not rejected; it is assigned `nanLabel` instead.
613 template <
class U,
class C1,
class T,
class C2>
616 LabelType nanLabel)
const
618 vigra_precondition(features.
shape(0) == labels.
shape(0),
619 "RandomForest::predictLabels(): Label array has wrong size.");
620 for(
int k=0; k<features.
shape(0); ++k)
622 if(detail::contains_nan(
rowVector(features, k)))
623 labels(k,0) = nanLabel;
// predictLabels(features, labels, stop): one label per row of `features`,
// with a Stop object that is presumably handed on to per-row prediction
// (the loop body is elided in this listing).
638 template <
class U,
class C1,
class T,
class C2,
class Stop>
643 vigra_precondition(features.
shape(0) == labels.
shape(0),
644 "RandomForest::predictLabels(): Label array has wrong size.");
645 for(
int k=0; k<features.
shape(0); ++k)
// Further member template declarations whose signatures are elided here.
// NOTE(review): judging by the template parameters and the out-of-class
// definitions below, these are presumably
//   660: a Stop-taking overload,
//   664: predictProbabilities(OnlinePredictionSet<T1>&, MultiArrayView<2,T2,C>&),
//   674: predictProbabilities(features, prob),
//   681: predictRaw(features, prob)
// -- confirm against the complete header.
660 template <
class U,
class C1,
class T,
class C2,
class Stop>
664 template <
class T1,
class T2,
class C>
674 template <
class U,
class C1,
class T,
class C2>
681 template <
class U,
class C1,
class T,
class C2>
// onlineLearn(): incrementally updates an already trained forest.
// For every tree, a Poisson sampler draws samples from the new rows
// (presumably rows [new_start_index, row_count_), judging by the sampler's
// arguments). Each drawn sample is routed to its leaf, and the leaf's
// parent is recorded. Every leaf that received a sample whose label it does
// not already predict with probability exactly 1.0 is regrown with
// continueLearn(), starting from that leaf's stored index list.
// Requires options_.prepare_online_learning_ to have been set at construction.
691 template <
class LabelType,
class PreprocessorTag>
692 template<
class U,
class C1,
698 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
699 MultiArrayView<2,U2,C2>
const & response,
705 bool adjust_thresholds)
707 online_visitor_.activate();
708 online_visitor_.adjust_thresholds=adjust_thresholds;
712 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
713 typedef UniformIntRandomFunctor<Random_t>
// RF_CHOOSER presumably picks the caller-supplied stop/split/visitor object,
// or the local default when the caller passed the rf_default() tag.
720 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
721 Default_Stop_t default_stop(options_);
722 typename RF_CHOOSER(Stop_t)::type stop
723 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
724 Default_Split_t default_split;
725 typename RF_CHOOSER(Split_t)::type split
726 = RF_CHOOSER(Split_t)::choose(split_, default_split);
727 rf::visitors::StopVisiting stopvisiting;
// The online-learning visitor is chained in front of the user's visitor.
728 typedef rf::visitors::detail::VisitorNode
729 <rf::visitors::OnlineLearnVisitor,
730 typename RF_CHOOSER(Visitor_t)::type>
733 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
735 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
// The class count is reset before preprocessing (presumably recomputed by
// the preprocessor from `response`).
741 ext_param_.class_count_=0;
742 Preprocessor_t preprocessor( features, response,
743 options_, ext_param_);
746 RandFunctor_t randint ( random);
749 split.set_external_parameters(ext_param_);
750 stop.set_external_parameters(ext_param_);
// First argument 1.0 is presumably the Poisson rate (online bagging).
754 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
760 for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
762 online_visitor_.tree_id=ii;
763 poisson_sampler.sample();
// leaf id -> parent node id of every leaf that must be regrown.
764 std::map<int,int> leaf_parents;
765 leaf_parents.clear();
767 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
769 int sample=poisson_sampler[s];
770 online_visitor_.current_label=preprocessor.response()(sample,0);
// getToLeaf() with the online visitor records the last visited inner node
// in last_node_id, i.e. the parent of the returned leaf.
771 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
776 online_visitor_.add_to_index_list(ii,leaf,sample);
// A leaf that already assigns probability 1.0 to this label needs no split.
779 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
781 leaf_parents[leaf]=online_visitor_.last_node_id;
// Regrow every recorded leaf from its accumulated sample index list.
786 std::map<int,int>::iterator leaf_iterator;
787 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
789 int leaf=leaf_iterator->first;
790 int parent=leaf_iterator->second;
791 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
792 ArrayVector<Int32> indeces;
// swap() takes ownership of the leaf's index list without copying it.
794 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
795 StackEntry_t stack_entry(indeces.begin(),
797 ext_param_.class_count_);
// Attach the regrown subtree on the same side of `parent` as the old leaf.
802 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
804 stack_entry.leftParent=parent;
808 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
809 stack_entry.rightParent=parent;
813 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
// NOTE(review): moves the online bookkeeping of the replaced leaf; exact
// semantics of move_exterior_node() are defined in rf_visitors.hxx.
815 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
828 online_visitor_.deactivate();
// reLearnTree(): discards tree `treeId` and grows it again from a fresh
// sample of the preprocessed training data.
// Requires options_.prepare_online_learning_. The online visitor is active
// while the tree is regrown, and its per-tree record is reset first.
831 template<
class LabelType,
class PreprocessorTag>
832 template<
class U,
class C1,
// Presumably recomputed by the preprocessor from `response`.
853 ext_param_.class_count_=0;
// Pick caller-supplied stop/split/visitor objects or the defaults.
861 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
863 typename RF_CHOOSER(Stop_t)::type stop
864 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
866 typename RF_CHOOSER(Split_t)::type split
867 = RF_CHOOSER(Split_t)::choose(split_, default_split);
871 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
873 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
875 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876 online_visitor_.activate();
879 RandFunctor_t randint ( random);
885 Preprocessor_t preprocessor( features, response,
886 options_, ext_param_);
889 split.set_external_parameters(ext_param_);
890 stop.set_external_parameters(ext_param_);
// Sampler over the preprocessor's strata, drawing actual_msample_ samples.
897 preprocessor.strata().end(),
898 detail::make_sampler_opt(options_)
899 .sampleSize(ext_param().actual_msample_),
// Root stack entry: in-bag indices to grow from, out-of-bag range attached.
906 first_stack_entry( sampler.sampledIndices().begin(),
907 sampler.sampledIndices().end(),
908 ext_param_.class_count_);
910 .set_oob_range( sampler.oobIndices().begin(),
911 sampler.oobIndices().end());
// Clear previous online information and structure of this tree.
912 online_visitor_.reset_tree(treeId);
913 online_visitor_.tree_id=treeId;
914 trees_[treeId].reset();
916 .learn( preprocessor.features(),
917 preprocessor.response(),
924 .visit_after_tree( *
this,
930 online_visitor_.deactivate();
// learn(): trains the whole forest.
// Checks that features and response have the same number of rows. It then
// preprocesses the data, resizes trees_ to options_.tree_count_, and grows
// each tree from its own sample (in-bag indices plus out-of-bag range).
// Visitors are notified at the beginning, after each tree and at the end.
933 template <
class LabelType,
class PreprocessorTag>
934 template <
class U,
class C1,
946 Random_t
const & random)
957 vigra_precondition(features.
shape(0) == response.
shape(0),
958 "RandomForest::learn(): shape mismatch between features and response.");
// Pick caller-supplied stop/split/visitor objects or the defaults.
965 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
967 typename RF_CHOOSER(Stop_t)::type stop
968 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
970 typename RF_CHOOSER(Split_t)::type split
971 = RF_CHOOSER(Split_t)::choose(split_, default_split);
975 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
977 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
// Online bookkeeping is only collected when prepared for online learning.
979 if(options_.prepare_online_learning_)
980 online_visitor_.activate();
// NOTE(review): presumably the (elided) else-branch of the check above.
982 online_visitor_.deactivate();
986 RandFunctor_t randint ( random);
993 Preprocessor_t preprocessor( features, response,
994 options_, ext_param_);
997 split.set_external_parameters(ext_param_);
998 stop.set_external_parameters(ext_param_);
1002 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
// Sampler over the preprocessor's strata, drawing actual_msample_ samples.
1005 preprocessor.strata().end(),
1006 detail::make_sampler_opt(options_)
1007 .sampleSize(ext_param().actual_msample_),
1010 visitor.visit_at_beginning(*
this, preprocessor);
1013 for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
// Root stack entry of tree ii: in-bag indices plus out-of-bag range.
1019 first_stack_entry( sampler.sampledIndices().begin(),
1020 sampler.sampledIndices().end(),
1021 ext_param_.class_count_);
1023 .set_oob_range( sampler.oobIndices().begin(),
1024 sampler.oobIndices().end());
1026 .learn( preprocessor.features(),
1027 preprocessor.response(),
1034 .visit_after_tree( *
this,
1041 visitor.visit_at_end(*
this, preprocessor);
1043 online_visitor_.deactivate();
// predictLabel(features, stop): predicts the label of a single sample.
// `features` must have exactly one row and at least column_count_ columns.
// Class probabilities come from predictProbabilities(features, ..., stop);
// the index of the most probable class is converted to a LabelType through
// ext_param_.to_classlabel().
1049 template <
class LabelType,
class Tag>
1050 template <
class U,
class C,
class Stop>
1054 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1055 "RandomForest::predictLabel():"
1056 " Too few columns in feature matrix.");
1057 vigra_precondition(
rowCount(features) == 1,
1058 "RandomForest::predictLabel():"
1059 " Feature matrix must have a single row.");
1062 predictProbabilities(features, probabilities, stop);
1063 ext_param_.to_classlabel(
argMax(probabilities), d);
1069 template <
class LabelType,
class PreprocessorTag>
1070 template <
class U,
class C>
1075 using namespace functor;
1076 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1077 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078 vigra_precondition(
rowCount(features) == 1,
1079 "RandomForestn::predictLabel():"
1080 " Feature matrix must have a single row.");
1081 Matrix<double> prob(1,ext_param_.class_count_);
1082 predictProbabilities(features, prob);
1083 std::transform( prob.begin(), prob.end(),
1084 priors.
begin(), prob.begin(),
1087 ext_param_.to_classlabel(
argMax(prob), d);
// predictProbabilities(predictionSet, prob): batch prediction on an
// OnlinePredictionSet. Each tree receives whole groups of samples
// (SampleRange entries of predictionSet.ranges[set_id]) instead of one
// sample at a time. A range is split in place only when its per-column
// min/max boundaries straddle a node's threshold.
// Only trees made of threshold nodes are supported; other inner nodes throw
// std::runtime_error. The per-tree num_decisions value is stored in
// predictionSet.cumulativePredTime[k], and each sample's probabilities are
// divided by its accumulated leaf weight at the end.
1091 template<
class LabelType,
class PreprocessorTag>
1092 template <
class T1,
class T2,
class C>
1101 "RandomForest::predictProbabilities():"
1102 " Feature matrix and probability matrix size mismatch.");
1105 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1106 "RandomForest::predictProbabilities():"
1107 " Too few columns in feature matrix.");
1109 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1110 "RandomForest::predictProbabilities():"
1111 " Probability matrix must have as many columns as there are classes.");
// Accumulated leaf weight per sample, used for the final normalization.
1114 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1117 for(
int k=0; k<options_.tree_count_; ++k)
// NOTE(review): the modulus is indices[0].size() (number of samples), not
// indices.size() (number of index sets) -- confirm this is intended.
1119 set_id=(set_id+1) % predictionSet.indices[0].size();
1120 typedef std::set<SampleRange<T1> > my_set;
1121 typedef typename my_set::iterator set_it;
// Work stack of (node index, sample range) pairs.
1124 std::vector<std::pair<int,set_it> > stack;
// Every range starts at node index 2 (presumably the root in topology_).
1126 for(set_it i=predictionSet.ranges[set_id].begin();
1127 i!=predictionSet.ranges[set_id].end();++i)
1128 stack.push_back(std::pair<int,set_it>(2,i));
1130 int num_decisions=0;
1131 while(!stack.empty())
1133 set_it range=stack.back().second;
1134 int index=stack.back().first;
// Leaf: add its class weights to every sample of the range.
1138 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1141 trees_[k].parameters_,
1142 index).prob_begin();
1143 for(
int i=range->start;i!=range->end;++i)
1146 for(
int l=0; l<ext_param_.class_count_; ++l)
1148 prob(predictionSet.indices[set_id][i], l) +=
static_cast<T2
>(weights[l]);
1150 totalWeights[predictionSet.indices[set_id][i]] +=
static_cast<T1
>(weights[l]);
1157 if(trees_[k].topology_[index]!=i_ThresholdNode)
1159 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1161 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
// Whole range lies at or above the threshold: it goes to child(1) unsplit.
1162 if(range->min_boundaries[node.column()]>=node.threshold())
1165 stack.push_back(std::pair<int,set_it>(node.child(1),range));
// Whole range lies below the threshold: it goes to child(0) unsplit.
1168 if(range->max_boundaries[node.column()]<node.threshold())
1171 stack.push_back(std::pair<int,set_it>(node.child(0),range));
// Otherwise partition the range in place: samples >= threshold are swapped
// to the end and form new_range; both boundaries are recomputed on the way.
1175 SampleRange<T1> new_range=*range;
1176 new_range.min_boundaries[node.column()]=FLT_MAX;
1177 range->max_boundaries[node.column()]=-FLT_MAX;
1178 new_range.start=new_range.end=range->end;
1180 while(i!=range->end)
1183 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1185 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1186 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1189 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1194 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1195 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
// Drop an empty left part, otherwise continue with it at child(0).
1200 if(range->start==range->end)
1202 predictionSet.ranges[set_id].erase(range);
1206 stack.push_back(std::pair<int,set_it>(node.child(0),range));
// A non-empty right part becomes a new range continuing at child(1).
1209 if(new_range.start!=new_range.end)
1211 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1212 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1216 predictionSet.cumulativePredTime[k]=num_decisions;
// Normalize each sample's probabilities by its accumulated leaf weight.
1218 for(
unsigned int i=0;i<totalWeights.size();++i)
1222 for(
int l=0; l<ext_param_.class_count_; ++l)
1225 prob(i, l) /= totalWeights[i];
// NOTE(review): debug-only checks; the first compares floating-point sums
// with ==, which may fail on rounding differences.
1227 assert(test==totalWeights[i]);
1228 assert(totalWeights[i]>0.0);
// predictProbabilities(features, prob, stop): class probabilities for every
// row of `features`. prob must have one row per sample and one column per
// class.
// For each row, every tree's leaf class weights are added to prob. When
// options_.predict_weighted_ is set, the factor also involves the value
// stored just before the leaf's class weights (*(weights-1)). The row is
// then divided by its total accumulated weight.
// stop.after_prediction() is consulted after each tree and presumably ends
// the tree loop early; NaN rows take a separate path (body elided here).
1232 template <
class LabelType,
class PreprocessorTag>
1233 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1236 MultiArrayView<2, T, C2> & prob,
1237 Stop_t & stop_)
const
1243 "RandomForest::predictProbabilities():"
1244 " Feature matrix and probability matrix size mismatch.");
1248 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1249 "RandomForest::predictProbabilities():"
1250 " Too few columns in feature matrix.");
1252 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1253 "RandomForest::predictProbabilities():"
1254 " Probability matrix must have as many columns as there are classes.");
// Use the caller's stop object or a default one built from options_.
1256 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1257 Default_Stop_t default_stop(options_);
1258 typename RF_CHOOSER(Stop_t)::type & stop
1259 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1261 stop.set_external_parameters(ext_param_, tree_count());
1262 prob.init(NumericTraits<T>::zero());
1272 for(
int row=0; row <
rowCount(features); ++row)
1274 MultiArrayView<2, U, StridedArrayTag> currentRow(
rowVector(features, row));
1278 if(detail::contains_nan(currentRow))
1284 ArrayVector<double>::const_iterator weights;
1287 double totalWeight = 0.0;
1290 for(
int k=0; k<options_.tree_count_; ++k)
1293 weights = trees_[k ].predict(currentRow);
1296 int weighted = options_.predict_weighted_;
1297 for(
int l=0; l<ext_param_.class_count_; ++l)
1299 double cur_w = weights[l] * (weighted * (*(weights-1))
1301 prob(row, l) +=
static_cast<T
>(cur_w);
1303 totalWeight += cur_w;
1305 if(stop.after_prediction(weights,
// Normalize the row so that its entries sum to one.
1315 for(
int l=0; l< ext_param_.class_count_; ++l)
1317 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
// predictRaw(features, prob): accumulates every tree's (optionally weighted)
// leaf class weights into prob, like predictProbabilities(). The result is
// divided by the number of trees rather than by each row's total weight.
// prob must have one row per sample and one column per class.
1323 template <
class LabelType,
class PreprocessorTag>
1324 template <
class U,
class C1,
class T,
class C2>
1325 void RandomForest<LabelType, PreprocessorTag>
1326 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1327 MultiArrayView<2, T, C2> & prob)
const
1333 "RandomForest::predictRaw():"
1334 " Feature matrix and probability matrix size mismatch.");
1338 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1339 "RandomForest::predictRaw():"
1340 " Too few columns in feature matrix.");
1342 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1343 "RandomForest::predictRaw():"
1344 " Probability matrix must have as many columns as there are classes.");
// NOTE(review): RF_CHOOSER is not used in the lines visible here.
1346 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1347 prob.init(NumericTraits<T>::zero());
1357 for(
int row=0; row <
rowCount(features); ++row)
1359 ArrayVector<double>::const_iterator weights;
1362 double totalWeight = 0.0;
1365 for(
int k=0; k<options_.tree_count_; ++k)
1368 weights = trees_[k ].predict(
rowVector(features, row));
1371 int weighted = options_.predict_weighted_;
1372 for(
int l=0; l<ext_param_.class_count_; ++l)
1374 double cur_w = weights[l] * (weighted * (*(weights-1))
1376 prob(row, l) +=
static_cast<T
>(cur_w);
1378 totalWeight += cur_w;
// Average over trees instead of normalizing per row.
1382 prob/= options_.tree_count_;
1388 #include "random_forest/rf_algorithm.hxx"
1389 #endif // VIGRA_RANDOM_FOREST_HXX
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:278
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:342
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function that returns an RF_DEFAULT tag
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
Definition: rf_preprocessing.hxx:63
int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:323
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:334
Create random samples from a sequence of indices.
Definition: sampling.hxx:232
const difference_type & shape() const
Definition: multi_array.hxx:1648
void sample()
Definition: sampling.hxx:467
Definition: rf_split.hxx:993
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition: random_forest.hxx:614
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:538
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:197
Standard early stopping criterion.
Definition: rf_common.hxx:885
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:260
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:315
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:308
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:291
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:471
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple samples (one per row of the feature matrix)
Definition: random_forest.hxx:675
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition: random_forest.hxx:147
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:301
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label for a single sample (a feature matrix with exactly one row).
Definition: random_forest.hxx:1052
Definition: rf_visitors.hxx:583
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition: random_forest.hxx:590
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:141
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition: sampling.hxx:83
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
int tree_count() const
return number of trees
Definition: random_forest.hxx:349
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:838
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:228
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Options object for the random forest.
Definition: rf_common.hxx:170
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition: random_forest.hxx:639
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1206
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:941
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple samples (one per row of the feature matrix)
Definition: rf_visitors.hxx:234