37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
44 struct ClassificationTag
69 friend RF_DEFAULT& ::vigra::rf_default();
99 template<
class T,
class C>
104 static T & choose(T & t, C &)
111 class Value_Chooser<detail::RF_DEFAULT, C>
116 static C & choose(detail::RF_DEFAULT &, C & c)
133 static detail::RF_DEFAULT result;
176 double training_set_proportion_;
177 int training_set_size_;
178 int (*training_set_func_)(int);
180 training_set_calc_switch_;
182 bool sample_with_replacement_;
184 stratification_method_;
195 int (*mtry_func_)(int) ;
197 bool predict_weighted_;
199 int min_split_node_size_;
200 bool prepare_online_learning_;
204 typedef std::map<std::string, double_array> map_type;
206 int serialized_size()
const
215 #define COMPARE(field) result = result && (this->field == rhs.field);
216 COMPARE(training_set_proportion_);
217 COMPARE(training_set_size_);
218 COMPARE(training_set_calc_switch_);
219 COMPARE(sample_with_replacement_);
220 COMPARE(stratification_method_);
221 COMPARE(mtry_switch_);
223 COMPARE(tree_count_);
224 COMPARE(min_split_node_size_);
225 COMPARE(predict_weighted_);
232 return !(*
this == rhs_);
235 void unserialize(Iter
const & begin, Iter
const & end)
238 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
239 "RandomForestOptions::unserialize():"
240 "wrong number of parameters");
241 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
242 PULL(training_set_proportion_,
double);
243 PULL(training_set_size_,
int);
246 PULL(sample_with_replacement_, 0 != );
251 PULL(tree_count_,
int);
252 PULL(min_split_node_size_,
int);
253 PULL(predict_weighted_, 0 !=);
257 void serialize(Iter
const & begin, Iter
const & end)
const
260 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
261 "RandomForestOptions::serialize():"
262 "wrong number of parameters");
263 #define PUSH(item_) *iter = double(item_); ++iter;
264 PUSH(training_set_proportion_);
265 PUSH(training_set_size_);
266 if(training_set_func_ != 0)
274 PUSH(training_set_calc_switch_);
275 PUSH(sample_with_replacement_);
276 PUSH(stratification_method_);
288 PUSH(min_split_node_size_);
289 PUSH(predict_weighted_);
293 void make_from_map(map_type & in)
295 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
296 #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
297 PULL(training_set_proportion_,
double);
298 PULL(training_set_size_,
int);
300 PULL(tree_count_,
int);
301 PULL(min_split_node_size_,
int);
302 PULLBOOL(sample_with_replacement_,
bool);
303 PULLBOOL(prepare_online_learning_,
bool);
304 PULLBOOL(predict_weighted_,
bool);
317 void make_map(map_type & in)
const
319 #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
320 #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0));
321 PUSH(training_set_proportion_,
double);
322 PUSH(training_set_size_,
int);
324 PUSH(tree_count_,
int);
325 PUSH(min_split_node_size_,
int);
326 PUSH(sample_with_replacement_,
bool);
327 PUSH(prepare_online_learning_,
bool);
328 PUSH(predict_weighted_,
bool);
334 PUSHFUNC(mtry_func_,
int);
335 PUSHFUNC(training_set_func_,
int);
348 training_set_proportion_(1.0),
349 training_set_size_(0),
350 training_set_func_(0),
351 training_set_calc_switch_(RF_PROPORTIONAL),
352 sample_with_replacement_(true),
353 stratification_method_(RF_NONE),
354 mtry_switch_(RF_SQRT),
357 predict_weighted_(false),
359 min_split_node_size_(1),
360 prepare_online_learning_(false)
376 vigra_precondition(in == RF_EQUAL ||
377 in == RF_PROPORTIONAL ||
380 "RandomForestOptions::use_stratification()"
381 "input must be RF_EQUAL, RF_PROPORTIONAL,"
382 "RF_EXTERNAL or RF_NONE");
383 stratification_method_ = in;
389 prepare_online_learning_=in;
399 sample_with_replacement_ = in;
413 training_set_proportion_ = in;
414 training_set_calc_switch_ = RF_PROPORTIONAL;
425 training_set_size_ = in;
426 training_set_calc_switch_ = RF_CONST;
438 training_set_func_ = in;
439 training_set_calc_switch_ = RF_FUNCTION;
447 predict_weighted_ =
true;
462 vigra_precondition(in == RF_LOG ||
465 "RandomForestOptions()::features_per_node():"
466 "input must be of type RF_LOG or RF_SQRT");
480 mtry_switch_ = RF_CONST;
492 mtry_switch_ = RF_FUNCTION;
516 min_split_node_size_ = in;
524 enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
537 template<
class LabelType =
double>
550 typedef std::map<std::string, double_array> map_type;
559 Problem_t problem_type_;
568 void to_classlabel(
int index, T & out)
const
570 out = T(classes[index]);
573 int to_classIndex(T index)
const
575 return std::find(classes.
begin(), classes.
end(), index) - classes.
begin();
578 #define EQUALS(field) field(rhs.field)
581 EQUALS(column_count_),
582 EQUALS(class_count_),
584 EQUALS(actual_mtry_),
585 EQUALS(actual_msample_),
586 EQUALS(problem_type_),
588 EQUALS(class_weights_),
589 EQUALS(is_weighted_),
591 EQUALS(response_size_)
593 std::back_insert_iterator<ArrayVector<Label_t> >
595 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
598 #define EQUALS(field) field(rhs.field)
602 EQUALS(column_count_),
603 EQUALS(class_count_),
605 EQUALS(actual_mtry_),
606 EQUALS(actual_msample_),
607 EQUALS(problem_type_),
609 EQUALS(class_weights_),
610 EQUALS(is_weighted_),
612 EQUALS(response_size_)
614 std::back_insert_iterator<ArrayVector<Label_t> >
616 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
620 #define EQUALS(field) (this->field = rhs.field);
623 EQUALS(column_count_);
624 EQUALS(class_count_);
626 EQUALS(actual_mtry_);
627 EQUALS(actual_msample_);
628 EQUALS(problem_type_);
630 EQUALS(is_weighted_);
632 EQUALS(response_size_)
633 class_weights_.clear();
634 std::back_insert_iterator<ArrayVector<
double> >
635 iter2(class_weights_);
636 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
638 std::back_insert_iterator<ArrayVector<
Label_t> >
640 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
647 EQUALS(column_count_);
648 EQUALS(class_count_);
650 EQUALS(actual_mtry_);
651 EQUALS(actual_msample_);
652 EQUALS(problem_type_);
654 EQUALS(is_weighted_);
656 EQUALS(response_size_)
657 class_weights_.clear();
658 std::back_insert_iterator<ArrayVector<
double> >
659 iter2(class_weights_);
660 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
662 std::back_insert_iterator<ArrayVector<
Label_t> >
664 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
670 bool operator==(ProblemSpec<T>
const & rhs)
673 #define COMPARE(field) result = result && (this->field == rhs.field);
674 COMPARE(column_count_);
675 COMPARE(class_count_);
677 COMPARE(actual_mtry_);
678 COMPARE(actual_msample_);
679 COMPARE(problem_type_);
680 COMPARE(is_weighted_);
683 COMPARE(class_weights_);
685 COMPARE(response_size_)
692 return !(*
this == rhs);
696 size_t serialized_size()
const
698 return 10 + class_count_ *int(is_weighted_+1);
703 void unserialize(Iter
const & begin, Iter
const & end)
706 vigra_precondition(end - begin >= 10,
707 "ProblemSpec::unserialize():"
708 "wrong number of parameters");
709 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
710 PULL(column_count_,
int);
711 PULL(class_count_,
int);
713 vigra_precondition(end - begin >= 10 + class_count_,
714 "ProblemSpec::unserialize(): 1");
715 PULL(row_count_,
int);
716 PULL(actual_mtry_,
int);
717 PULL(actual_msample_,
int);
718 PULL(problem_type_, Problem_t);
719 PULL(is_weighted_,
int);
721 PULL(precision_,
double);
722 PULL(response_size_,
int);
725 vigra_precondition(end - begin == 10 + 2*class_count_,
726 "ProblemSpec::unserialize(): 2");
727 class_weights_.insert(class_weights_.end(),
729 iter + class_count_);
730 iter += class_count_;
732 classes.insert(classes.end(), iter, end);
738 void serialize(Iter
const & begin, Iter
const & end)
const
741 vigra_precondition(end - begin == serialized_size(),
742 "RandomForestOptions::serialize():"
743 "wrong number of parameters");
744 #define PUSH(item_) *iter = double(item_); ++iter;
749 PUSH(actual_msample_);
754 PUSH(response_size_);
757 std::copy(class_weights_.begin(),
758 class_weights_.end(),
760 iter += class_count_;
762 std::copy(classes.begin(),
768 void make_from_map(map_type & in)
770 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
771 PULL(column_count_,
int);
772 PULL(class_count_,
int);
773 PULL(row_count_,
int);
774 PULL(actual_mtry_,
int);
775 PULL(actual_msample_,
int);
776 PULL(problem_type_, (Problem_t)
int);
777 PULL(is_weighted_,
int);
779 PULL(precision_,
double);
780 PULL(response_size_,
int);
781 class_weights_ = in[
"class_weights_"];
784 void make_map(map_type & in)
const
786 #define PUSH(item_) in[#item_] = double_array(1, double(item_));
791 PUSH(actual_msample_);
796 PUSH(response_size_);
797 in["class_weights_"] = class_weights_;
809 problem_type_(CHECKLATER),
827 template<
class C_Iter>
831 int size = end-begin;
832 for(
int k=0; k<size; ++k, ++begin)
833 classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
843 template<
class W_Iter>
846 class_weights_.clear();
847 class_weights_.insert(class_weights_.end(), begin, end);
858 class_weights_.clear();
863 problem_type_ = CHECKLATER;
864 is_weighted_ =
false;
888 int min_split_node_size_;
892 : min_split_node_size_(opt.min_split_node_size_)
896 void set_external_parameters(
ProblemSpec<T>const &,
int = 0,
bool =
false)
899 template<
class Region>
900 bool operator()(Region& region)
902 return region.size() < min_split_node_size_;
905 template<
class WeightIter,
class T,
class C>
915 #endif //VIGRA_RF_COMMON_HXX
RandomForestOptions & features_per_node(RF_OptionTag in)
use built-in mapping to calculate mtry
Definition: rf_common.hxx:460
detail::RF_DEFAULT & rf_default()
factory function to return an RF_DEFAULT tag
Definition: rf_common.hxx:131
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition: rf_common.hxx:411
RandomForestOptions & features_per_node(int(*in)(int))
use an external function to calculate mtry
Definition: rf_common.hxx:489
const_iterator begin() const
Definition: array_vector.hxx:223
RandomForestOptions & samples_per_tree(int in)
directly specify the number of samples per tree
Definition: rf_common.hxx:423
RandomForestOptions & samples_per_tree(int(*in)(int))
use an external function to calculate the number of samples each tree should be learnt with...
Definition: rf_common.hxx:436
problem specification class for the random forest.
Definition: rf_common.hxx:538
LabelType Label_t
problem class
Definition: rf_common.hxx:547
RandomForestOptions & min_split_node_size(int in)
Number of examples required for a node to be split.
Definition: rf_common.hxx:514
RandomForestOptions & features_per_node(int in)
Set mtry to a constant value.
Definition: rf_common.hxx:477
Standard early stopping criterion.
Definition: rf_common.hxx:885
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels -
Definition: rf_common.hxx:828
RandomForestOptions()
create a RandomForestOptions object with default initialisation.
Definition: rf_common.hxx:346
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights -
Definition: rf_common.hxx:844
RandomForestOptions & sample_with_replacement(bool in)
sample from training population with or without replacement?
Definition: rf_common.hxx:397
RandomForestOptions & predict_weighted()
weight each tree with number of samples in that node
Definition: rf_common.hxx:445
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Options object for the random forest.
Definition: rf_common.hxx:170
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition: rf_common.hxx:374
RandomForestOptions & tree_count(unsigned int in)
Definition: rf_common.hxx:500
const_iterator end() const
Definition: array_vector.hxx:237
ProblemSpec()
set default values (-> values not set)
Definition: rf_common.hxx:803
RF_OptionTag
Definition: rf_common.hxx:140