[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_common.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
39 
40 namespace vigra
41 {
42 
43 
/** \brief tag type: dispatch RandomForest learning/prediction to the classification variant */
struct ClassificationTag
{};
46 
/** \brief tag type: dispatch RandomForest learning/prediction to the regression variant */
struct RegressionTag
{};
49 
namespace detail
{
    // forward declaration: rf_default() must be declarable before the
    // definition of RF_DEFAULT below, which befriends it
    class RF_DEFAULT;
}
// factory returning the singleton RF_DEFAULT tag; defined later in this file
inline detail::RF_DEFAULT& rf_default();
55 namespace detail
56 {
57 
58 /* \brief singleton default tag class -
59  *
60  * use the rf_default() factory function to use the tag.
61  * \sa RandomForest<>::learn();
62  */
class RF_DEFAULT
{
  private:
    // not publicly constructible: the only instance is the function-local
    // static inside the rf_default() factory
    RF_DEFAULT()
    {}
  public:
    // rf_default() is the sole factory and needs access to the private ctor
    friend RF_DEFAULT& ::vigra::rf_default();

    /** ok workaround for automatic choice of the decisiontree
     * stackentry.
     */
};
75 
76 /* \brief chooses between default type and type supplied
77  *
78  * This is an internal class and you shouldn't really care about it.
79  * It is just passed on and used inside RandomForest::learn().
80  * Usage:
81  *\code
82  * // example: use container type supplied by user or ArrayVector if
83  * // rf_default() was specified as argument;
84  * template<class Container_t>
85  * void do_some_foo(Container_t in)
86  * {
87  * typedef ArrayVector<int> Default_Container_t;
88  * Default_Container_t default_value;
89  * Value_Chooser<Container_t, Default_Container_t>
90  * choose(in, default_value);
91  *
92  * // if the user didn't care and the in was of type
93  * // RF_DEFAULT then default_value is used.
94  * do_some_more_foo(choose.value());
95  * }
96  * Value_Chooser choose_val<Type, Default_Type>
97  *\endcode
98  */
99 template<class T, class C>
100 class Value_Chooser
101 {
102 public:
103  typedef T type;
104  static T & choose(T & t, C &)
105  {
106  return t;
107  }
108 };
109 
// Specialization: the caller passed the rf_default() tag, i.e. expressed no
// preference — discard the tag and hand back the default value instead.
template<class C>
class Value_Chooser<detail::RF_DEFAULT, C>
{
public:
    typedef C type;

    static C & choose(detail::RF_DEFAULT &, C & c)
    {
        return c;
    }
};
121 
122 
123 
124 
125 } //namespace detail
126 
127 
128 /**\brief factory function to return a RF_DEFAULT tag
129  * \sa RandomForest<>::learn()
130  */
131 detail::RF_DEFAULT& rf_default()
132 {
133  static detail::RF_DEFAULT result;
134  return result;
135 }
136 
137 /** tags used with the RandomForestOptions class
138  * \sa RF_Traits::Option_t
139  */
enum RF_OptionTag   { RF_EQUAL,        // stratification: equal number of samples per class
                      RF_PROPORTIONAL, // sample count proportional (to total rows / class frequency)
                      RF_EXTERNAL,     // strata weights supplied externally via ProblemSpec
                      RF_NONE,         // no stratification
                      RF_FUNCTION,     // value computed by a user-supplied callback
                      RF_LOG,          // mtry = 1 + floor(log2(column count))
                      RF_SQRT,         // mtry = floor(sqrt(column count) + 0.5)
                      RF_CONST,        // value given as an explicit constant
                      RF_ALL};         // consider all features for each split
149 
150 
151 /** \addtogroup MachineLearning
152 **/
153 //@{
154 
155 /**\brief Options object for the random forest
156  *
157  * usage:
158  * RandomForestOptions a = RandomForestOptions()
159  * .param1(value1)
160  * .param2(value2)
161  * ...
162  *
163  * This class only contains options/parameters that are not problem
164  * dependent. The ProblemSpec class contains methods to set class weights
165  * if necessary.
166  *
167  * Note that the return value of all methods is *this which makes
168  * concatenating of options as above possible.
169  */
171 {
172  public:
173  /**\name sampling options*/
174  /*\{*/
175  // look at the member access functions for documentation
176  double training_set_proportion_;
177  int training_set_size_;
178  int (*training_set_func_)(int);
180  training_set_calc_switch_;
181 
182  bool sample_with_replacement_;
184  stratification_method_;
185 
186 
187  /**\name general random forest options
188  *
189  * these usually will be used by most split functors and
190  * stopping predicates
191  */
192  /*\{*/
193  RF_OptionTag mtry_switch_;
194  int mtry_;
195  int (*mtry_func_)(int) ;
196 
197  bool predict_weighted_;
198  int tree_count_;
199  int min_split_node_size_;
200  bool prepare_online_learning_;
201  /*\}*/
202 
204  typedef std::map<std::string, double_array> map_type;
205 
206  int serialized_size() const
207  {
208  return 12;
209  }
210 
211 
212  bool operator==(RandomForestOptions & rhs) const
213  {
214  bool result = true;
215  #define COMPARE(field) result = result && (this->field == rhs.field);
216  COMPARE(training_set_proportion_);
217  COMPARE(training_set_size_);
218  COMPARE(training_set_calc_switch_);
219  COMPARE(sample_with_replacement_);
220  COMPARE(stratification_method_);
221  COMPARE(mtry_switch_);
222  COMPARE(mtry_);
223  COMPARE(tree_count_);
224  COMPARE(min_split_node_size_);
225  COMPARE(predict_weighted_);
226  #undef COMPARE
227 
228  return result;
229  }
230  bool operator!=(RandomForestOptions & rhs_) const
231  {
232  return !(*this == rhs_);
233  }
234  template<class Iter>
235  void unserialize(Iter const & begin, Iter const & end)
236  {
237  Iter iter = begin;
238  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
239  "RandomForestOptions::unserialize():"
240  "wrong number of parameters");
241  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
242  PULL(training_set_proportion_, double);
243  PULL(training_set_size_, int);
244  ++iter; //PULL(training_set_func_, double);
245  PULL(training_set_calc_switch_, (RF_OptionTag)int);
246  PULL(sample_with_replacement_, 0 != );
247  PULL(stratification_method_, (RF_OptionTag)int);
248  PULL(mtry_switch_, (RF_OptionTag)int);
249  PULL(mtry_, int);
250  ++iter; //PULL(mtry_func_, double);
251  PULL(tree_count_, int);
252  PULL(min_split_node_size_, int);
253  PULL(predict_weighted_, 0 !=);
254  #undef PULL
255  }
256  template<class Iter>
257  void serialize(Iter const & begin, Iter const & end) const
258  {
259  Iter iter = begin;
260  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
261  "RandomForestOptions::serialize():"
262  "wrong number of parameters");
263  #define PUSH(item_) *iter = double(item_); ++iter;
264  PUSH(training_set_proportion_);
265  PUSH(training_set_size_);
266  if(training_set_func_ != 0)
267  {
268  PUSH(1);
269  }
270  else
271  {
272  PUSH(0);
273  }
274  PUSH(training_set_calc_switch_);
275  PUSH(sample_with_replacement_);
276  PUSH(stratification_method_);
277  PUSH(mtry_switch_);
278  PUSH(mtry_);
279  if(mtry_func_ != 0)
280  {
281  PUSH(1);
282  }
283  else
284  {
285  PUSH(0);
286  }
287  PUSH(tree_count_);
288  PUSH(min_split_node_size_);
289  PUSH(predict_weighted_);
290  #undef PUSH
291  }
292 
293  void make_from_map(map_type & in) // -> const: .operator[] -> .find
294  {
295  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
296  #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
297  PULL(training_set_proportion_,double);
298  PULL(training_set_size_, int);
299  PULL(mtry_, int);
300  PULL(tree_count_, int);
301  PULL(min_split_node_size_, int);
302  PULLBOOL(sample_with_replacement_, bool);
303  PULLBOOL(prepare_online_learning_, bool);
304  PULLBOOL(predict_weighted_, bool);
305 
306  PULL(training_set_calc_switch_, (RF_OptionTag)(int));
307 
308  PULL(stratification_method_, (RF_OptionTag)(int));
309  PULL(mtry_switch_, (RF_OptionTag)(int));
310 
311  /*don't pull*/
312  //PULL(mtry_func_!=0, int);
313  //PULL(training_set_func,int);
314  #undef PULL
315  #undef PULLBOOL
316  }
317  void make_map(map_type & in) const
318  {
319  #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
320  #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0));
321  PUSH(training_set_proportion_,double);
322  PUSH(training_set_size_, int);
323  PUSH(mtry_, int);
324  PUSH(tree_count_, int);
325  PUSH(min_split_node_size_, int);
326  PUSH(sample_with_replacement_, bool);
327  PUSH(prepare_online_learning_, bool);
328  PUSH(predict_weighted_, bool);
329 
330  PUSH(training_set_calc_switch_, RF_OptionTag);
331  PUSH(stratification_method_, RF_OptionTag);
332  PUSH(mtry_switch_, RF_OptionTag);
333 
334  PUSHFUNC(mtry_func_, int);
335  PUSHFUNC(training_set_func_,int);
336  #undef PUSH
337  #undef PUSHFUNC
338  }
339 
340 
341  /**\brief create a RandomForestOptions object with default initialisation.
342  *
343  * look at the other member functions for more information on default
344  * values
345  */
347  :
348  training_set_proportion_(1.0),
349  training_set_size_(0),
350  training_set_func_(0),
351  training_set_calc_switch_(RF_PROPORTIONAL),
352  sample_with_replacement_(true),
353  stratification_method_(RF_NONE),
354  mtry_switch_(RF_SQRT),
355  mtry_(0),
356  mtry_func_(0),
357  predict_weighted_(false),
358  tree_count_(255),
359  min_split_node_size_(1),
360  prepare_online_learning_(false)
361  {}
362 
363  /**\brief specify stratification strategy
364  *
365  * default: RF_NONE
366  * possible values: RF_EQUAL, RF_PROPORTIONAL,
367  * RF_EXTERNAL, RF_NONE
368  * RF_EQUAL: get equal amount of samples per class.
369  * RF_PROPORTIONAL: sample proportional to fraction of class samples
370  * in population
371  * RF_EXTERNAL: strata_weights_ field of the ProblemSpec_t object
372  * has been set externally. (defunct)
373  */
375  {
376  vigra_precondition(in == RF_EQUAL ||
377  in == RF_PROPORTIONAL ||
378  in == RF_EXTERNAL ||
379  in == RF_NONE,
380  "RandomForestOptions::use_stratification()"
381  "input must be RF_EQUAL, RF_PROPORTIONAL,"
382  "RF_EXTERNAL or RF_NONE");
383  stratification_method_ = in;
384  return *this;
385  }
386 
387  RandomForestOptions & prepare_online_learning(bool in)
388  {
389  prepare_online_learning_=in;
390  return *this;
391  }
392 
393  /**\brief sample from training population with or without replacement?
394  *
395  * <br> Default: true
396  */
398  {
399  sample_with_replacement_ = in;
400  return *this;
401  }
402 
403  /**\brief specify the fraction of the total number of samples
404  * used per tree for learning.
405  *
406  * This value should be in [0.0 1.0] if sampling without
407  * replacement has been specified.
408  *
409  * <br> default : 1.0
410  */
412  {
413  training_set_proportion_ = in;
414  training_set_calc_switch_ = RF_PROPORTIONAL;
415  return *this;
416  }
417 
418  /**\brief directly specify the number of samples per tree
419  *
420  * This value should not be higher than the total number of
421  * samples if sampling without replacement has been specified.
422  */
424  {
425  training_set_size_ = in;
426  training_set_calc_switch_ = RF_CONST;
427  return *this;
428  }
429 
430  /**\brief use external function to calculate the number of samples each
431  * tree should be learnt with.
432  *
433  * \param in function pointer that takes the number of rows in the
434  * learning data and outputs the number samples per tree.
435  */
437  {
438  training_set_func_ = in;
439  training_set_calc_switch_ = RF_FUNCTION;
440  return *this;
441  }
442 
443  /**\brief weight each tree with number of samples in that node
444  */
446  {
447  predict_weighted_ = true;
448  return *this;
449  }
450 
451  /**\brief use built in mapping to calculate mtry
452  *
453  * Use one of the built in mappings to calculate mtry from the number
454  * of columns in the input feature data.
455  * \param in possible values:
456  * - RF_LOG (the number of features considered for each split is \f$ 1+\lfloor \log(n_f)/\log(2) \rfloor \f$ as in Breiman's original paper),
457  * - RF_SQRT (default, the number of features considered for each split is \f$ \lfloor \sqrt{n_f} + 0.5 \rfloor \f$)
458  * - RF_ALL (all features are considered for each split)
459  */
461  {
462  vigra_precondition(in == RF_LOG ||
463  in == RF_SQRT||
464  in == RF_ALL,
465  "RandomForestOptions()::features_per_node():"
466  "input must be of type RF_LOG or RF_SQRT");
467  mtry_switch_ = in;
468  return *this;
469  }
470 
471  /**\brief Set mtry to a constant value
472  *
473  * mtry is the number of columns/variates/variables randomly chosen
474  * to select the best split from.
475  *
476  */
478  {
479  mtry_ = in;
480  mtry_switch_ = RF_CONST;
481  return *this;
482  }
483 
484  /**\brief use a external function to calculate mtry
485  *
486  * \param in function pointer that takes int (number of columns
487  * of the and outputs int (mtry)
488  */
490  {
491  mtry_func_ = in;
492  mtry_switch_ = RF_FUNCTION;
493  return *this;
494  }
495 
496  /** How many trees to create?
497  *
498  * <br> Default: 255.
499  */
500  RandomForestOptions & tree_count(unsigned int in)
501  {
502  tree_count_ = in;
503  return *this;
504  }
505 
506  /**\brief Number of examples required for a node to be split.
507  *
508  * When the number of examples in a node is below this number,
509  * the node is not split even if class separation is not yet perfect.
510  * Instead, the node returns the proportion of each class
511  * (among the remaining examples) during the prediction phase.
512  * <br> Default: 1 (complete growing)
513  */
515  {
516  min_split_node_size_ = in;
517  return *this;
518  }
519 };
520 
521 
522 /* \brief problem types
523  */
524 enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
525 
526 
527 /** \brief problem specification class for the random forest.
528  *
529  * This class contains all the problem specific parameters the random
530  * forest needs for learning. Specification of an instance of this class
531  * is optional as all necessary fields will be computed prior to learning
532  * if not specified.
533  *
534  * if needed usage is similar to that of RandomForestOptions
535  */
536 
537 template<class LabelType = double>
539 {
540 
541 
542 public:
543 
544  /** \brief problem class
545  */
546 
547  typedef LabelType Label_t;
548  ArrayVector<Label_t> classes;
550  typedef std::map<std::string, double_array> map_type;
551 
552  int column_count_; // number of features
553  int class_count_; // number of classes
554  int row_count_; // number of samples
555 
556  int actual_mtry_; // mtry used in training
557  int actual_msample_; // number if in-bag samples per tree
558 
559  Problem_t problem_type_; // classification or regression
560 
561  int used_; // this ProblemSpec is valid
562  ArrayVector<double> class_weights_; // if classes have different importance
563  int is_weighted_; // class_weights_ are used
564  double precision_; // termination criterion for regression loss
565  int response_size_;
566 
567  template<class T>
568  void to_classlabel(int index, T & out) const
569  {
570  out = T(classes[index]);
571  }
572  template<class T>
573  int to_classIndex(T index) const
574  {
575  return std::find(classes.begin(), classes.end(), index) - classes.begin();
576  }
577 
578  #define EQUALS(field) field(rhs.field)
579  ProblemSpec(ProblemSpec const & rhs)
580  :
581  EQUALS(column_count_),
582  EQUALS(class_count_),
583  EQUALS(row_count_),
584  EQUALS(actual_mtry_),
585  EQUALS(actual_msample_),
586  EQUALS(problem_type_),
587  EQUALS(used_),
588  EQUALS(class_weights_),
589  EQUALS(is_weighted_),
590  EQUALS(precision_),
591  EQUALS(response_size_)
592  {
593  std::back_insert_iterator<ArrayVector<Label_t> >
594  iter(classes);
595  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
596  }
597  #undef EQUALS
598  #define EQUALS(field) field(rhs.field)
599  template<class T>
600  ProblemSpec(ProblemSpec<T> const & rhs)
601  :
602  EQUALS(column_count_),
603  EQUALS(class_count_),
604  EQUALS(row_count_),
605  EQUALS(actual_mtry_),
606  EQUALS(actual_msample_),
607  EQUALS(problem_type_),
608  EQUALS(used_),
609  EQUALS(class_weights_),
610  EQUALS(is_weighted_),
611  EQUALS(precision_),
612  EQUALS(response_size_)
613  {
614  std::back_insert_iterator<ArrayVector<Label_t> >
615  iter(classes);
616  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
617  }
618  #undef EQUALS
619 
620  #define EQUALS(field) (this->field = rhs.field);
621  ProblemSpec & operator=(ProblemSpec const & rhs)
622  {
623  EQUALS(column_count_);
624  EQUALS(class_count_);
625  EQUALS(row_count_);
626  EQUALS(actual_mtry_);
627  EQUALS(actual_msample_);
628  EQUALS(problem_type_);
629  EQUALS(used_);
630  EQUALS(is_weighted_);
631  EQUALS(precision_);
632  EQUALS(response_size_)
633  class_weights_.clear();
634  std::back_insert_iterator<ArrayVector<double> >
635  iter2(class_weights_);
636  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
637  classes.clear();
638  std::back_insert_iterator<ArrayVector<Label_t> >
639  iter(classes);
640  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
641  return *this;
642  }
643 
644  template<class T>
645  ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
646  {
647  EQUALS(column_count_);
648  EQUALS(class_count_);
649  EQUALS(row_count_);
650  EQUALS(actual_mtry_);
651  EQUALS(actual_msample_);
652  EQUALS(problem_type_);
653  EQUALS(used_);
654  EQUALS(is_weighted_);
655  EQUALS(precision_);
656  EQUALS(response_size_)
657  class_weights_.clear();
658  std::back_insert_iterator<ArrayVector<double> >
659  iter2(class_weights_);
660  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
661  classes.clear();
662  std::back_insert_iterator<ArrayVector<Label_t> >
663  iter(classes);
664  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
665  return *this;
666  }
667  #undef EQUALS
668 
669  template<class T>
670  bool operator==(ProblemSpec<T> const & rhs)
671  {
672  bool result = true;
673  #define COMPARE(field) result = result && (this->field == rhs.field);
674  COMPARE(column_count_);
675  COMPARE(class_count_);
676  COMPARE(row_count_);
677  COMPARE(actual_mtry_);
678  COMPARE(actual_msample_);
679  COMPARE(problem_type_);
680  COMPARE(is_weighted_);
681  COMPARE(precision_);
682  COMPARE(used_);
683  COMPARE(class_weights_);
684  COMPARE(classes);
685  COMPARE(response_size_)
686  #undef COMPARE
687  return result;
688  }
689 
690  bool operator!=(ProblemSpec & rhs)
691  {
692  return !(*this == rhs);
693  }
694 
695 
696  size_t serialized_size() const
697  {
698  return 10 + class_count_ *int(is_weighted_+1);
699  }
700 
701 
702  template<class Iter>
703  void unserialize(Iter const & begin, Iter const & end)
704  {
705  Iter iter = begin;
706  vigra_precondition(end - begin >= 10,
707  "ProblemSpec::unserialize():"
708  "wrong number of parameters");
709  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
710  PULL(column_count_,int);
711  PULL(class_count_, int);
712 
713  vigra_precondition(end - begin >= 10 + class_count_,
714  "ProblemSpec::unserialize(): 1");
715  PULL(row_count_, int);
716  PULL(actual_mtry_,int);
717  PULL(actual_msample_, int);
718  PULL(problem_type_, Problem_t);
719  PULL(is_weighted_, int);
720  PULL(used_, int);
721  PULL(precision_, double);
722  PULL(response_size_, int);
723  if(is_weighted_)
724  {
725  vigra_precondition(end - begin == 10 + 2*class_count_,
726  "ProblemSpec::unserialize(): 2");
727  class_weights_.insert(class_weights_.end(),
728  iter,
729  iter + class_count_);
730  iter += class_count_;
731  }
732  classes.insert(classes.end(), iter, end);
733  #undef PULL
734  }
735 
736 
737  template<class Iter>
738  void serialize(Iter const & begin, Iter const & end) const
739  {
740  Iter iter = begin;
741  vigra_precondition(end - begin == serialized_size(),
742  "RandomForestOptions::serialize():"
743  "wrong number of parameters");
744  #define PUSH(item_) *iter = double(item_); ++iter;
745  PUSH(column_count_);
746  PUSH(class_count_)
747  PUSH(row_count_);
748  PUSH(actual_mtry_);
749  PUSH(actual_msample_);
750  PUSH(problem_type_);
751  PUSH(is_weighted_);
752  PUSH(used_);
753  PUSH(precision_);
754  PUSH(response_size_);
755  if(is_weighted_)
756  {
757  std::copy(class_weights_.begin(),
758  class_weights_.end(),
759  iter);
760  iter += class_count_;
761  }
762  std::copy(classes.begin(),
763  classes.end(),
764  iter);
765  #undef PUSH
766  }
767 
768  void make_from_map(map_type & in) // -> const: .operator[] -> .find
769  {
770  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
771  PULL(column_count_,int);
772  PULL(class_count_, int);
773  PULL(row_count_, int);
774  PULL(actual_mtry_,int);
775  PULL(actual_msample_, int);
776  PULL(problem_type_, (Problem_t)int);
777  PULL(is_weighted_, int);
778  PULL(used_, int);
779  PULL(precision_, double);
780  PULL(response_size_, int);
781  class_weights_ = in["class_weights_"];
782  #undef PULL
783  }
784  void make_map(map_type & in) const
785  {
786  #define PUSH(item_) in[#item_] = double_array(1, double(item_));
787  PUSH(column_count_);
788  PUSH(class_count_)
789  PUSH(row_count_);
790  PUSH(actual_mtry_);
791  PUSH(actual_msample_);
792  PUSH(problem_type_);
793  PUSH(is_weighted_);
794  PUSH(used_);
795  PUSH(precision_);
796  PUSH(response_size_);
797  in["class_weights_"] = class_weights_;
798  #undef PUSH
799  }
800 
801  /**\brief set default values (-> values not set)
802  */
804  : column_count_(0),
805  class_count_(0),
806  row_count_(0),
807  actual_mtry_(0),
808  actual_msample_(0),
809  problem_type_(CHECKLATER),
810  used_(false),
811  is_weighted_(false),
812  precision_(0.0),
813  response_size_(1)
814  {}
815 
816 
817  ProblemSpec & column_count(int in)
818  {
819  column_count_ = in;
820  return *this;
821  }
822 
823  /**\brief supply with class labels -
824  *
825  * the preprocessor will not calculate the labels needed in this case.
826  */
827  template<class C_Iter>
828  ProblemSpec & classes_(C_Iter begin, C_Iter end)
829  {
830  classes.clear();
831  int size = end-begin;
832  for(int k=0; k<size; ++k, ++begin)
833  classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
834  class_count_ = size;
835  return *this;
836  }
837 
838  /** \brief supply with class weights -
839  *
840  * this is the only case where you would really have to
841  * create a ProblemSpec object.
842  */
843  template<class W_Iter>
844  ProblemSpec & class_weights(W_Iter begin, W_Iter end)
845  {
846  class_weights_.clear();
847  class_weights_.insert(class_weights_.end(), begin, end);
848  is_weighted_ = true;
849  return *this;
850  }
851 
852 
853 
854  void clear()
855  {
856  used_ = false;
857  classes.clear();
858  class_weights_.clear();
859  column_count_ = 0 ;
860  class_count_ = 0;
861  actual_mtry_ = 0;
862  actual_msample_ = 0;
863  problem_type_ = CHECKLATER;
864  is_weighted_ = false;
865  precision_ = 0.0;
866  response_size_ = 0;
867 
868  }
869 
870  bool used() const
871  {
872  return used_ != 0;
873  }
874 };
875 
876 
877 //@}
878 
879 
880 
881 /**\brief Standard early stopping criterion
882  *
883  * Stop if region.size() < min_split_node_size_;
884  */
886 {
887  public:
888  int min_split_node_size_;
889 
890  template<class Opt>
891  EarlyStoppStd(Opt opt)
892  : min_split_node_size_(opt.min_split_node_size_)
893  {}
894 
895  template<class T>
896  void set_external_parameters(ProblemSpec<T>const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
897  {}
898 
899  template<class Region>
900  bool operator()(Region& region)
901  {
902  return region.size() < min_split_node_size_;
903  }
904 
905  template<class WeightIter, class T, class C>
906  bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
907  {
908  return false;
909  }
910 };
911 
912 
913 } // namespace vigra
914 
915 #endif //VIGRA_RF_COMMON_HXX
RandomForestOptions & features_per_node(RF_OptionTag in)
use built in mapping to calculate mtry
Definition: rf_common.hxx:460
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition: rf_common.hxx:131
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition: rf_common.hxx:411
RandomForestOptions & features_per_node(int(*in)(int))
use an external function to calculate mtry
Definition: rf_common.hxx:489
const_iterator begin() const
Definition: array_vector.hxx:223
RandomForestOptions & samples_per_tree(int in)
directly specify the number of samples per tree
Definition: rf_common.hxx:423
RandomForestOptions & samples_per_tree(int(*in)(int))
use external function to calculate the number of samples each tree should be learnt with...
Definition: rf_common.hxx:436
problem specification class for the random forest.
Definition: rf_common.hxx:538
LabelType Label_t
problem class
Definition: rf_common.hxx:547
RandomForestOptions & min_split_node_size(int in)
Number of examples required for a node to be split.
Definition: rf_common.hxx:514
RandomForestOptions & features_per_node(int in)
Set mtry to a constant value.
Definition: rf_common.hxx:477
Standard early stopping criterion.
Definition: rf_common.hxx:885
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels -
Definition: rf_common.hxx:828
RandomForestOptions()
create a RandomForestOptions object with default initialisation.
Definition: rf_common.hxx:346
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights -
Definition: rf_common.hxx:844
RandomForestOptions & sample_with_replacement(bool in)
sample from training population with or without replacement?
Definition: rf_common.hxx:397
RandomForestOptions & predict_weighted()
weight each tree with number of samples in that node
Definition: rf_common.hxx:445
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Options object for the random forest.
Definition: rf_common.hxx:170
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition: rf_common.hxx:374
RandomForestOptions & tree_count(unsigned int in)
Definition: rf_common.hxx:500
const_iterator end() const
Definition: array_vector.hxx:237
ProblemSpec()
set default values (-> values not set)
Definition: rf_common.hxx:803
RF_OptionTag
Definition: rf_common.hxx:140

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)