[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
39 
40 #include <iostream>
41 #include <algorithm>
42 #include <map>
43 #include <set>
44 #include <list>
45 #include <numeric>
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
49 #include "matrix.hxx"
50 #include "metaprogramming.hxx"
51 #include "random.hxx"
52 #include "functorexpression.hxx"
53 #include "random_forest/rf_common.hxx"
54 #include "random_forest/rf_nodeproxy.hxx"
55 #include "random_forest/rf_split.hxx"
56 #include "random_forest/rf_decisionTree.hxx"
57 #include "random_forest/rf_visitors.hxx"
58 #include "random_forest/rf_region.hxx"
59 #include "sampling.hxx"
60 #include "random_forest/rf_preprocessing.hxx"
61 #include "random_forest/rf_online_prediction_set.hxx"
62 #include "random_forest/rf_earlystopping.hxx"
63 #include "random_forest/rf_ridge_split.hxx"
64 namespace vigra
65 {
66 
67 /** \addtogroup MachineLearning Machine Learning
68 
69  This module provides classification algorithms that map
70  features to labels or label probabilities.
71  Look at the \ref vigra::RandomForest class (for implementation version 2) or the
72  \ref vigra::rf3::random_forest() factory function (for implementation version 3)
73  for an overview of the functionality as well as use cases.
74 **/
75 
76 namespace detail
77 {
78 
79 
80 
81 /* \brief sampling option factory function
82  */
83 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 {
85  SamplerOptions return_opt;
86  return_opt.withReplacement(RF_opt.sample_with_replacement_);
87  return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
88  return return_opt;
89 }
90 }//namespace detail
91 
92 /** \brief Random forest version 2 (see also \ref vigra::rf3::RandomForest for version 3)
93  *
94  * \ingroup MachineLearning
95  *
96  * \tparam <LabelType = double> Type used for predicted labels.
97  * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
98  * the input while learning and predicting. Currently Available:
99  * ClassificationTag and RegressionTag. It is recommended to use
100  * Splitfunctor::Preprocessor_t while using custom splitfunctors
101  * as they may need the data to be in a different format.
102  * \sa Preprocessor
103  *
104  * Simple usage for classification (regression is not yet supported):
105  * look at RandomForest::learn() as well as RandomForestOptions() for additional
106  * options.
107  *
108  * \code
109  * using namespace vigra;
110  * using namespace rf;
111  * typedef xxx feature_t; // replace xxx with whichever type
112  * typedef yyy label_t; // likewise
113  *
114  * // allocate the training data
115  * MultiArrayView<2, feature_t> f = get_training_features();
116  * MultiArrayView<2, label_t> l = get_training_labels();
117  *
118  * RandomForest<label_t> rf;
119  *
120  * // construct visitor to calculate out-of-bag error
121  * visitors::OOB_Error oob_v;
122  *
123  * // perform training
124  * rf.learn(f, l, visitors::create_visitor(oob_v));
125  *
126  * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
127  *
128  * // get features for new data to be used for prediction
129  * MultiArrayView<2, feature_t> pf = get_features();
130  *
131  * // allocate space for the response (pf.shape(0) is the number of samples)
132  * MultiArrayView<2, label_t> prediction(pf.shape(0), 1);
133  * MultiArrayView<2, double> prob(pf.shape(0), rf.class_count());
134  *
135  * // perform prediction on new data
136  * rf.predictLabels(pf, prediction);
137  * rf.predictProbabilities(pf, prob);
138  *
139  * \endcode
140  *
141  * Additional information such as Variable Importance measures are accessed
142  * via Visitors defined in rf::visitors.
143  * Have a look at rf::split for other splitting methods.
144  *
145 */
146 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
148 {
149 
150  public:
151  //public typedefs
153  typedef detail::DecisionTree DecisionTree_t;
155  typedef GiniSplit Default_Split_t;
159  StackEntry_t;
160  typedef LabelType LabelT;
161 
162  //problem independent data.
163  Options_t options_;
164  //problem dependent data members - is only set if
165  //a copy constructor, some sort of import
166  //function or the learn function is called
168  ProblemSpec_t ext_param_;
169  /*mutable ArrayVector<int> tree_indices_;*/
170  rf::visitors::OnlineLearnVisitor online_visitor_;
171 
172 
173  void reset()
174  {
175  ext_param_.clear();
176  trees_.clear();
177  }
178 
179  public:
180 
181  /** \name Constructors
182  * Note: No copy constructor specified as no pointers are manipulated
183  * in this class
184 
185  * @{
186  */
187 
188  /**\brief default constructor
189  *
190  * \param options general options to the Random Forest. Must be of Type
191  * Options_t
192  * \param ext_param problem specific values that can be supplied
193  * additionally. (class weights , labels etc)
194  * \sa RandomForestOptions, ProblemSpec
195  *
196  */
199  :
200  options_(options),
201  ext_param_(ext_param)/*,
202  tree_indices_(options.tree_count_,0)*/
203  {
204  /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
205  tree_indices_[ii] = ii;*/
206  }
207 
208  /**\brief Create RF from external source
209  * \param treeCount Number of trees to add.
210  * \param topology_begin
211  * Iterator to a Container where the topology_ data
212  * of the trees are stored.
213  * Iterator should support at least treeCount forward
214  * iterations. (i.e. topology_end - topology_begin >= treeCount
215  * \param parameter_begin
216  * iterator to a Container where the parameters_ data
217  * of the trees are stored. Iterator should support at
218  * least treeCount forward iterations.
219  * \param problem_spec
220  * Extrinsic parameters that specify the problem e.g.
221  * ClassCount, featureCount etc.
222  * \param options (optional) specify options used to train the original
223  * Random forest. This parameter is not used anywhere
224  * during prediction and thus is optional.
225  *
226  */
227  template<class TopologyIterator, class ParameterIterator>
228  RandomForest(int treeCount,
229  TopologyIterator topology_begin,
230  ParameterIterator parameter_begin,
231  ProblemSpec_t const & problem_spec,
232  Options_t const & options = Options_t())
233  :
234  trees_(treeCount, DecisionTree_t(problem_spec)),
235  ext_param_(problem_spec),
236  options_(options)
237  {
238  /* TODO: This constructor may be replaced by a Constructor using
239  * NodeProxy iterators to encapsulate the underlying data type.
240  */
241  for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
242  {
243  trees_[k].topology_ = *topology_begin;
244  trees_[k].parameters_ = *parameter_begin;
245  }
246  }
247 
248  /** @} */
249 
250 
251  /** \name Data Access
252  * data access interface - usage of member variables is deprecated
253  *
254  * @{
255  */
256 
257  /**\brief return external parameters for viewing
258  * \return ProblemSpec_t
259  */
260  ProblemSpec_t const & ext_param() const
261  {
262  vigra_precondition(ext_param_.used() == true,
263  "RandomForest::ext_param(): "
264  "Random forest has not been trained yet.");
265  return ext_param_;
266  }
267 
268  /**\brief set external parameters
269  *
270  * \param in external parameters to be set
271  *
272  * set external parameters explicitly.
273  * If Random Forest has not been trained the preprocessor will
274  * either ignore filling values set this way or will throw an exception
275  * if values specified manually do not match the value calculated
276  & during the preparation step.
277  */
278  void set_ext_param(ProblemSpec_t const & in)
279  {
280  ignore_argument(in);
281  vigra_precondition(ext_param_.used() == false,
282  "RandomForest::set_ext_param():"
283  "Random forest has been trained! Call reset()"
284  "before specifying new extrinsic parameters.");
285  }
286 
287  /**\brief access random forest options
288  *
289  * \return random forest options
290  */
292  {
293  return options_;
294  }
295 
296 
297  /**\brief access const random forest options
298  *
299  * \return const Option_t
300  */
301  Options_t const & options() const
302  {
303  return options_;
304  }
305 
306  /**\brief access const trees
307  */
308  DecisionTree_t const & tree(int index) const
309  {
310  return trees_[index];
311  }
312 
313  /**\brief access trees
314  */
315  DecisionTree_t & tree(int index)
316  {
317  return trees_[index];
318  }
319 
320  /**\brief return number of features used while
321  * training.
322  */
323  int feature_count() const
324  {
325  return ext_param_.column_count_;
326  }
327 
328 
329  /**\brief return number of features used while
330  * training.
331  *
332  * deprecated. Use feature_count() instead.
333  */
334  int column_count() const
335  {
336  return ext_param_.column_count_;
337  }
338 
339  /**\brief return number of classes used while
340  * training.
341  */
342  int class_count() const
343  {
344  return ext_param_.class_count_;
345  }
346 
347  /**\brief return number of trees
348  */
349  int tree_count() const
350  {
351  return options_.tree_count_;
352  }
353 
354  /** @} */
355 
356  /**\name Learning
357  * Following functions differ in the degree of customization
358  * allowed
359  *
360  * @{
361  */
362 
363  /**\brief learn on data with custom config and random number generator
364  *
365  * \param features a N x M matrix containing N samples with M
366  * features
367  * \param response a N x D matrix containing the corresponding
368  * response. Current split functors assume D to
369  * be 1 and ignore any additional columns.
370  * This is not enforced to allow future support
371  * for uncertain labels, label independent strata etc.
372  * The Preprocessor specified during construction
373  * should be able to handle both the features
374  * and the labels.
375  * see also: SplitFunctor, Preprocessing
376  *
377  * \param visitor visitor which is to be applied after each split,
378  * tree and at the end. Use rf_default() for using
379  * default value. (No Visitors)
380  * see also: rf::visitors
381  * \param split split functor to be used to calculate each split
382  * use rf_default() for using default value. (GiniSplit)
383  * see also: rf::split
384  * \param stop
385  * predicate to be used to calculate each split
386  * use rf_default() for using default value. (EarlyStoppStd)
387  * \param random RandomNumberGenerator to be used. Use
388  * rf_default() to use default value. (RandomMT19937)
389  *
390  *
391  */
392  template <class U, class C1,
393  class U2,class C2,
394  class Split_t,
395  class Stop_t,
396  class Visitor_t,
397  class Random_t>
398  void learn( MultiArrayView<2, U, C1> const & features,
399  MultiArrayView<2, U2,C2> const & response,
400  Visitor_t visitor,
401  Split_t split,
402  Stop_t stop,
403  Random_t const & random);
404 
405  template <class U, class C1,
406  class U2,class C2,
407  class Split_t,
408  class Stop_t,
409  class Visitor_t>
410  void learn( MultiArrayView<2, U, C1> const & features,
411  MultiArrayView<2, U2,C2> const & response,
412  Visitor_t visitor,
413  Split_t split,
414  Stop_t stop)
415 
416  {
418  learn( features,
419  response,
420  visitor,
421  split,
422  stop,
423  rnd);
424  }
425 
426  template <class U, class C1, class U2,class C2, class Visitor_t>
427  void learn( MultiArrayView<2, U, C1> const & features,
428  MultiArrayView<2, U2,C2> const & labels,
429  Visitor_t visitor)
430  {
431  learn( features,
432  labels,
433  visitor,
434  rf_default(),
435  rf_default());
436  }
437 
438  template <class U, class C1, class U2,class C2,
439  class Visitor_t, class Split_t>
440  void learn( MultiArrayView<2, U, C1> const & features,
441  MultiArrayView<2, U2,C2> const & labels,
442  Visitor_t visitor,
443  Split_t split)
444  {
445  learn( features,
446  labels,
447  visitor,
448  split,
449  rf_default());
450  }
451 
452  /**\brief learn on data with default configuration
453  *
454  * \param features a N x M matrix containing N samples with M
455  * features
456  * \param labels a N x D matrix containing the corresponding
457  * N labels. Current split functors assume D to
458  * be 1 and ignore any additional columns.
459  * this is not enforced to allow future support
460  * for uncertain labels.
461  *
462  * learning is done with:
463  *
464  * \sa rf::split, EarlyStoppStd
465  *
466  * - Randomly seeded random number generator
467  * - default gini split functor as described by Breiman
468  * - default The standard early stopping criterion
469  */
470  template <class U, class C1, class U2,class C2>
471  void learn( MultiArrayView<2, U, C1> const & features,
472  MultiArrayView<2, U2,C2> const & labels)
473  {
474  learn( features,
475  labels,
476  rf_default(),
477  rf_default(),
478  rf_default());
479  }
480 
481 
482  template<class U,class C1,
483  class U2, class C2,
484  class Split_t,
485  class Stop_t,
486  class Visitor_t,
487  class Random_t>
488  void onlineLearn( MultiArrayView<2,U,C1> const & features,
489  MultiArrayView<2,U2,C2> const & response,
490  int new_start_index,
491  Visitor_t visitor_,
492  Split_t split_,
493  Stop_t stop_,
494  Random_t & random,
495  bool adjust_thresholds=false);
496 
497  template <class U, class C1, class U2,class C2>
498  void onlineLearn( MultiArrayView<2, U, C1> const & features,
499  MultiArrayView<2, U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false)
500  {
502  onlineLearn(features,
503  labels,
504  new_start_index,
505  rf_default(),
506  rf_default(),
507  rf_default(),
508  rnd,
509  adjust_thresholds);
510  }
511 
512  template<class U,class C1,
513  class U2, class C2,
514  class Split_t,
515  class Stop_t,
516  class Visitor_t,
517  class Random_t>
518  void reLearnTree(MultiArrayView<2,U,C1> const & features,
519  MultiArrayView<2,U2,C2> const & response,
520  int treeId,
521  Visitor_t visitor_,
522  Split_t split_,
523  Stop_t stop_,
524  Random_t & random);
525 
526  template<class U, class C1, class U2, class C2>
527  void reLearnTree(MultiArrayView<2, U, C1> const & features,
528  MultiArrayView<2, U2, C2> const & labels,
529  int treeId)
530  {
531  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
532  reLearnTree(features,
533  labels,
534  treeId,
535  rf_default(),
536  rf_default(),
537  rf_default(),
538  rnd);
539  }
540 
541  /** @} */
542 
543 
544 
545  /**\name Prediction
546  *
547  * @{
548  */
549 
550  /** \brief predict a label given a feature.
551  *
552  * \param features: a 1 by featureCount matrix containing
553  * data point to be predicted (this only works in
554  * classification setting)
555  * \param stop: early stopping criterion
556  * \return double value representing class. You can use the
557  * predictLabels() function together with the
558  * rf.external_parameter().class_type_ attribute
559  * to get back the same type used during learning.
560  */
561  template <class U, class C, class Stop>
562  LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
563 
564  template <class U, class C>
565  LabelType predictLabel(MultiArrayView<2, U, C>const & features)
566  {
567  return predictLabel(features, rf_default());
568  }
569  /** \brief predict a label with features and class priors
570  *
571  * \param features: same as above.
572  * \param prior: iterator to prior weighting of classes
573  * \return same as above.
574  */
575  template <class U, class C>
576  LabelType predictLabel(MultiArrayView<2, U, C> const & features,
577  ArrayVectorView<double> prior) const;
578 
579  /** \brief predict multiple labels with given features
580  *
581  * \param features: a n by featureCount matrix containing
582  * data point to be predicted (this only works in
583  * classification setting)
584  * \param labels: a n by 1 matrix passed by reference to store
585  * output.
586  *
587  * If the input contains a NaN value, a precondition exception is thrown.
588  */
589  template <class U, class C1, class T, class C2>
591  MultiArrayView<2, T, C2> & labels) const
592  {
593  vigra_precondition(features.shape(0) == labels.shape(0),
594  "RandomForest::predictLabels(): Label array has wrong size.");
595  for(int k=0; k<features.shape(0); ++k)
596  {
597  vigra_precondition(!detail::contains_nan(rowVector(features, k)),
598  "RandomForest::predictLabels(): NaN in feature matrix.");
599  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
600  }
601  }
602 
603  /** \brief predict multiple labels with given features
604  *
605  * \param features: a n by featureCount matrix containing
606  * data point to be predicted (this only works in
607  * classification setting)
608  * \param labels: a n by 1 matrix passed by reference to store
609  * output.
610  * \param nanLabel: label to be returned for the row of the input that
611  * contain an NaN value.
612  */
613  template <class U, class C1, class T, class C2>
615  MultiArrayView<2, T, C2> & labels,
616  LabelType nanLabel) const
617  {
618  vigra_precondition(features.shape(0) == labels.shape(0),
619  "RandomForest::predictLabels(): Label array has wrong size.");
620  for(int k=0; k<features.shape(0); ++k)
621  {
622  if(detail::contains_nan(rowVector(features, k)))
623  labels(k,0) = nanLabel;
624  else
625  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
626  }
627  }
628 
629  /** \brief predict multiple labels with given features
630  *
631  * \param features: a n by featureCount matrix containing
632  * data point to be predicted (this only works in
633  * classification setting)
634  * \param labels: a n by 1 matrix passed by reference to store
635  * output.
636  * \param stop: an early stopping criterion.
637  */
638  template <class U, class C1, class T, class C2, class Stop>
640  MultiArrayView<2, T, C2> & labels,
641  Stop & stop) const
642  {
643  vigra_precondition(features.shape(0) == labels.shape(0),
644  "RandomForest::predictLabels(): Label array has wrong size.");
645  for(int k=0; k<features.shape(0); ++k)
646  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
647  }
648  /** \brief predict the class probabilities for multiple labels
649  *
650  * \param features same as above
651  * \param prob a n x class_count_ matrix. passed by reference to
652  * save class probabilities
653  * \param stop earlystopping criterion
654  * \sa EarlyStopping
655 
656  When a row of the feature array contains an NaN, the corresponding instance
657  cannot belong to any of the classes. The corresponding row in the probability
658  array will therefore contain all zeros.
659  */
660  template <class U, class C1, class T, class C2, class Stop>
661  void predictProbabilities(MultiArrayView<2, U, C1>const & features,
663  Stop & stop) const;
664  template <class T1,class T2, class C>
665  void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
666  MultiArrayView<2, T2, C> & prob);
667 
668  /** \brief predict the class probabilities for multiple labels
669  *
670  * \param features same as above
671  * \param prob a n x class_count_ matrix. passed by reference to
672  * save class probabilities
673  */
674  template <class U, class C1, class T, class C2>
676  MultiArrayView<2, T, C2> & prob) const
677  {
678  predictProbabilities(features, prob, rf_default());
679  }
680 
681  template <class U, class C1, class T, class C2>
682  void predictRaw(MultiArrayView<2, U, C1>const & features,
683  MultiArrayView<2, T, C2> & prob) const;
684 
685 
686  /** @} */
687 
688 };
689 
690 
691 template <class LabelType, class PreprocessorTag>
692 template<class U,class C1,
693  class U2, class C2,
694  class Split_t,
695  class Stop_t,
696  class Visitor_t,
697  class Random_t>
698 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
699  MultiArrayView<2,U2,C2> const & response,
700  int new_start_index,
701  Visitor_t visitor_,
702  Split_t split_,
703  Stop_t stop_,
704  Random_t & random,
705  bool adjust_thresholds)
706 {
707  online_visitor_.activate();
708  online_visitor_.adjust_thresholds=adjust_thresholds;
709 
710  using namespace rf;
711  //typedefs
712  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
713  typedef UniformIntRandomFunctor<Random_t>
714  RandFunctor_t;
715  // default values and initialization
716  // Value Chooser chooses second argument as value if first argument
717  // is of type RF_DEFAULT. (thanks to template magic - don't care about
718  // it - just smile and wave.
719 
720  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
721  Default_Stop_t default_stop(options_);
722  typename RF_CHOOSER(Stop_t)::type stop
723  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
724  Default_Split_t default_split;
725  typename RF_CHOOSER(Split_t)::type split
726  = RF_CHOOSER(Split_t)::choose(split_, default_split);
727  rf::visitors::StopVisiting stopvisiting;
728  typedef rf::visitors::detail::VisitorNode
729  <rf::visitors::OnlineLearnVisitor,
730  typename RF_CHOOSER(Visitor_t)::type>
731  IntermedVis;
732  IntermedVis
733  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
734  #undef RF_CHOOSER
735  vigra_precondition(options_.prepare_online_learning_,"onlineLearn: online learning must be enabled on RandomForest construction");
736 
737  // Preprocess the data to get something the split functor can work
738  // with. Also fill the ext_param structure by preprocessing
739  // option parameters that could only be completely evaluated
740  // when the training data is known.
741  ext_param_.class_count_=0;
742  Preprocessor_t preprocessor( features, response,
743  options_, ext_param_);
744 
745  // Make stl compatible random functor.
746  RandFunctor_t randint ( random);
747 
748  // Give the Split functor information about the data.
749  split.set_external_parameters(ext_param_);
750  stop.set_external_parameters(ext_param_);
751 
752 
753  //Create poisson samples
754  PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));
755 
756  //TODO: visitors for online learning
757  //visitor.visit_at_beginning(*this, preprocessor);
758 
759  // THE MAIN EFFING RF LOOP - YEAY DUDE!
760  for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
761  {
762  online_visitor_.tree_id=ii;
763  poisson_sampler.sample();
764  std::map<int,int> leaf_parents;
765  leaf_parents.clear();
766  //Get all the leaf nodes for that sample
767  for(int s=0;s<poisson_sampler.numOfSamples();++s)
768  {
769  int sample=poisson_sampler[s];
770  online_visitor_.current_label=preprocessor.response()(sample,0);
771  online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772  int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
773 
774 
775  //Add to the list for that leaf
776  online_visitor_.add_to_index_list(ii,leaf,sample);
777  //TODO: Class count?
778  //Store parent
779  if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
780  {
781  leaf_parents[leaf]=online_visitor_.last_node_id;
782  }
783  }
784 
785 
786  std::map<int,int>::iterator leaf_iterator;
787  for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
788  {
789  int leaf=leaf_iterator->first;
790  int parent=leaf_iterator->second;
791  int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
792  ArrayVector<Int32> indeces;
793  indeces.clear();
794  indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
795  StackEntry_t stack_entry(indeces.begin(),
796  indeces.end(),
797  ext_param_.class_count_);
798 
799 
800  if(parent!=-1)
801  {
802  if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
803  {
804  stack_entry.leftParent=parent;
805  }
806  else
807  {
808  vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
809  stack_entry.rightParent=parent;
810  }
811  }
812  //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
813  trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
814  //Now, the last one moved onto leaf
815  online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
816  //Now it should be classified correctly!
817  }
818 
819  /*visitor
820  .visit_after_tree( *this,
821  preprocessor,
822  poisson_sampler,
823  stack_entry,
824  ii);*/
825  }
826 
827  //visitor.visit_at_end(*this, preprocessor);
828  online_visitor_.deactivate();
829 }
830 
831 template<class LabelType, class PreprocessorTag>
832 template<class U,class C1,
833  class U2, class C2,
834  class Split_t,
835  class Stop_t,
836  class Visitor_t,
837  class Random_t>
839  MultiArrayView<2,U2,C2> const & response,
840  int treeId,
841  Visitor_t visitor_,
842  Split_t split_,
843  Stop_t stop_,
844  Random_t & random)
845 {
846  using namespace rf;
847 
848 
850  RandFunctor_t;
851 
852  // See rf_preprocessing.hxx for more info on this
853  ext_param_.class_count_=0;
855 
856  // default values and initialization
857  // Value Chooser chooses second argument as value if first argument
858  // is of type RF_DEFAULT. (thanks to template magic - don't care about
859  // it - just smile and wave.
860 
861  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
862  Default_Stop_t default_stop(options_);
863  typename RF_CHOOSER(Stop_t)::type stop
864  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
865  Default_Split_t default_split;
866  typename RF_CHOOSER(Split_t)::type split
867  = RF_CHOOSER(Split_t)::choose(split_, default_split);
868  rf::visitors::StopVisiting stopvisiting;
871  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
872  IntermedVis
873  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
874  #undef RF_CHOOSER
875  vigra_precondition(options_.prepare_online_learning_,"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876  online_visitor_.activate();
877 
878  // Make stl compatible random functor.
879  RandFunctor_t randint ( random);
880 
881  // Preprocess the data to get something the split functor can work
882  // with. Also fill the ext_param structure by preprocessing
883  // option parameters that could only be completely evaluated
884  // when the training data is known.
885  Preprocessor_t preprocessor( features, response,
886  options_, ext_param_);
887 
888  // Give the Split functor information about the data.
889  split.set_external_parameters(ext_param_);
890  stop.set_external_parameters(ext_param_);
891 
892  /**\todo replace this crappy class out. It uses function pointers.
893  * and is making code slower according to me.
894  * Comment from Nathan: This is copied from Rahul, so me=Rahul
895  */
896  Sampler<Random_t > sampler(preprocessor.strata().begin(),
897  preprocessor.strata().end(),
898  detail::make_sampler_opt(options_)
899  .sampleSize(ext_param().actual_msample_),
900  &random);
901  //initialize First region/node/stack entry
902  sampler
903  .sample();
904 
906  first_stack_entry( sampler.sampledIndices().begin(),
907  sampler.sampledIndices().end(),
908  ext_param_.class_count_);
909  first_stack_entry
910  .set_oob_range( sampler.oobIndices().begin(),
911  sampler.oobIndices().end());
912  online_visitor_.reset_tree(treeId);
913  online_visitor_.tree_id=treeId;
914  trees_[treeId].reset();
915  trees_[treeId]
916  .learn( preprocessor.features(),
917  preprocessor.response(),
918  first_stack_entry,
919  split,
920  stop,
921  visitor,
922  randint);
923  visitor
924  .visit_after_tree( *this,
925  preprocessor,
926  sampler,
927  first_stack_entry,
928  treeId);
929 
930  online_visitor_.deactivate();
931 }
932 
// RandomForest::learn() -- train all trees of the forest.
//
// Parameters (as far as this listing shows them):
//   response  response matrix; must agree with the feature matrix in shape(0).
//   visitor_  user visitor, or a default tag (a StopVisiting object is then used).
//   split_    user split functor, or a default tag (Default_Split_t is then used).
//   stop_     user stop criterion, or a default tag
//             (Default_Stop_t(options_) is then used).
//   random    random number generator; wrapped in RandFunctor_t and also
//             handed to the sampler by address.
//
// Effects visible in the body: ext_param_ is passed to the preprocessor,
// trees_ is resized to options_.tree_count_, every tree in trees_ is learned
// on its own sample, and online_visitor_ is switched according to
// options_.prepare_online_learning_ and deactivated again at the end.
//
// NOTE(review): this listing is incomplete. Listing lines 940-941 (return
// type, function name and the first parameter -- the body calls it
// `features`), 951 (start of the RandFunctor_t typedef), 955 (presumably the
// Preprocessor_t typedef) and 973-974 (start of the IntermedVis typedef) are
// missing from the extract.
933 template <class LabelType, class PreprocessorTag>
934 template <class U, class C1,
935  class U2,class C2,
936  class Split_t,
937  class Stop_t,
938  class Visitor_t,
939  class Random_t>
942  MultiArrayView<2, U2,C2> const & response,
943  Visitor_t visitor_,
944  Split_t split_,
945  Stop_t stop_,
946  Random_t const & random)
947 {
948  using namespace rf;
949  //this->reset();
950  //typedefs
952  RandFunctor_t;
953 
954  // See rf_preprocessing.hxx for more info on this
956 
      // Features and response must agree in their first dimension.
957  vigra_precondition(features.shape(0) == response.shape(0),
958  "RandomForest::learn(): shape mismatch between features and response.");
959 
960  // default values and initialization
961  // Value_Chooser selects its second argument as the value if the first
962  // argument is of type RF_DEFAULT; otherwise the user-supplied object
963  // is used.
964 
965  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
966  Default_Stop_t default_stop(options_);
967  typename RF_CHOOSER(Stop_t)::type stop
968  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
969  Default_Split_t default_split;
970  typename RF_CHOOSER(Split_t)::type split
971  = RF_CHOOSER(Split_t)::choose(split_, default_split);
972  rf::visitors::StopVisiting stopvisiting;
975  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
      // The visitor actually used combines online_visitor_ with the chosen
      // (user or default) visitor.
976  IntermedVis
977  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
978  #undef RF_CHOOSER
      // The online-learning visitor is only active when the options ask for it.
979  if(options_.prepare_online_learning_)
980  online_visitor_.activate();
981  else
982  online_visitor_.deactivate();
983 
984 
985  // Make stl compatible random functor.
986  RandFunctor_t randint ( random);
987 
988 
989  // Preprocess the data to get something the split functor can work
990  // with. Also fill the ext_param structure by preprocessing
991  // option parameters that could only be completely evaluated
992  // when the training data is known.
993  Preprocessor_t preprocessor( features, response,
994  options_, ext_param_);
995 
996  // Give the split and stop functors information about the data.
997  split.set_external_parameters(ext_param_);
998  stop.set_external_parameters(ext_param_);
999 
1000 
1001  //initialize trees.
1002  trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
1003 
      // The sampler draws from the preprocessor's strata; the sample size is
      // ext_param().actual_msample_ and it is given the address of `random`.
1004  Sampler<Random_t > sampler(preprocessor.strata().begin(),
1005  preprocessor.strata().end(),
1006  detail::make_sampler_opt(options_)
1007  .sampleSize(ext_param().actual_msample_),
1008  &random);
1009 
1010  visitor.visit_at_beginning(*this, preprocessor);
1011  // Main training loop: one fresh sample and one learned tree per iteration.
1012 
1013  for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
1014  {
1015  //initialize First region/node/stack entry
1016  sampler
1017  .sample();
      // The first stack entry spans the sampled indices; the sampler's
      // oobIndices() are attached as its oob range.
1018  StackEntry_t
1019  first_stack_entry( sampler.sampledIndices().begin(),
1020  sampler.sampledIndices().end(),
1021  ext_param_.class_count_);
1022  first_stack_entry
1023  .set_oob_range( sampler.oobIndices().begin(),
1024  sampler.oobIndices().end());
1025  trees_[ii]
1026  .learn( preprocessor.features(),
1027  preprocessor.response(),
1028  first_stack_entry,
1029  split,
1030  stop,
1031  visitor,
1032  randint);
      // Let the visitor inspect the finished tree together with the sampler
      // state that produced it.
1033  visitor
1034  .visit_after_tree( *this,
1035  preprocessor,
1036  sampler,
1037  first_stack_entry,
1038  ii);
1039  }
1040 
1041  visitor.visit_at_end(*this, preprocessor);
1042  // Only for online learning?
1043  online_visitor_.deactivate();
1044 }
1045 
1046 
1047 
1048 
// RandomForest::predictLabel() -- predict the label of a single sample.
//
// `features` must have exactly one row and at least
// ext_param_.column_count_ columns (both enforced by the preconditions
// below). `stop` is forwarded to predictProbabilities().
// Returns the label that ext_param_.to_classlabel() writes for
// argMax(probabilities) (presumably the index of the largest class
// probability -- confirm against the argMax overload used here).
//
// NOTE(review): listing line 1051 (return type / class qualifier) is missing
// from this extract.
1049 template <class LabelType, class Tag>
1050 template <class U, class C, class Stop>
1052  ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
1053 {
1054  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1055  "RandomForestn::predictLabel():"
1056  " Too few columns in feature matrix.");
1057  vigra_precondition(rowCount(features) == 1,
1058  "RandomForestn::predictLabel():"
1059  " Feature matrix must have a singlerow.");
      // 1 x class_count_ buffer, zero-initialised, filled by predictProbabilities().
1060  MultiArray<2, double> probabilities(Shape2(1, ext_param_.class_count_), 0.0);
1061  LabelType d;
1062  predictProbabilities(features, probabilities, stop);
      // Convert the winning entry into a label; the result is written to d.
1063  ext_param_.to_classlabel(argMax(probabilities), d);
1064  return d;
1065 }
1066 
1067 
1068 //Same thing as above with priors for each label !!!
// RandomForest::predictLabel() with per-class priors -- same as the overload
// above, but every class probability is multiplied by the corresponding
// entry of `priors` before the winning class is chosen.
//
// `features` must have exactly one row and at least
// ext_param_.column_count_ columns.
//
// NOTE(review): listing lines 1071-1072 (return type, function name and the
// `features` parameter) are missing from this extract.
// NOTE(review): std::transform below reads one element of `priors` per
// element of `prob`; no check that `priors` holds at least
// ext_param_.class_count_ entries is visible here.
1069 template <class LabelType, class PreprocessorTag>
1070 template <class U, class C>
1073  ArrayVectorView<double> priors) const
1074 {
1075  using namespace functor;
1076  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1077  "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078  vigra_precondition(rowCount(features) == 1,
1079  "RandomForestn::predictLabel():"
1080  " Feature matrix must have a single row.");
1081  Matrix<double> prob(1,ext_param_.class_count_);
1082  predictProbabilities(features, prob);
      // In-place element-wise product: prob[i] = prob[i] * priors[i].
1083  std::transform( prob.begin(), prob.end(),
1084  priors.begin(), prob.begin(),
1085  Arg1()*Arg2());
1086  LabelType d;
1087  ext_param_.to_classlabel(argMax(prob), d);
1088  return d;
1089 }
1090 
// RandomForest::predictProbabilities() for an OnlinePredictionSet.
//
// Instead of sending every sample down every tree individually, whole
// ranges of sample indices are pushed through each tree. A range carries
// per-column min/max boundaries; if these lie entirely on one side of a
// node's threshold the complete range follows that child, otherwise the
// range is split in two at that node. At a leaf, the leaf's class weights
// are added to `prob` for every sample of the range.
//
// predictionSet  in/out: its index arrays are permuted and its range sets
//                are modified while splitting; cumulativePredTime[k] receives
//                the number of stack entries processed for tree k.
// prob           out: rowCount(predictionSet.features) x class_count_;
//                zeroed, accumulated, then each row is divided by its total
//                weight.
//
// Throws std::runtime_error when an interior node is not an i_ThresholdNode.
//
// NOTE(review): listing line 1093 (return type / class qualifier) is missing
// from this extract.
// NOTE(review): range->max_boundaries, range->end etc. are written through a
// std::set iterator; presumably those SampleRange members are declared
// mutable -- confirm in the SampleRange definition.
// NOTE(review): the first precondition message spells the class name
// "RandomFroest".
1091 template<class LabelType,class PreprocessorTag>
1092 template <class T1,class T2, class C>
1094  ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
1095  MultiArrayView<2, T2, C> & prob)
1096 {
1097  //Features are n xp
1098  //prob is n x NumOfLabel probability for each feature in each class
1099 
1100  vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1101  "RandomFroest::predictProbabilities():"
1102  " Feature matrix and probability matrix size mismatch.");
1103  // The feature matrix may have more columns than the forest was trained
1104  // on (the check is >=), but not fewer.
1105  vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
1106  "RandomForestn::predictProbabilities():"
1107  " Too few columns in feature matrix.");
1108  vigra_precondition( columnCount(prob)
1109  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1110  "RandomForestn::predictProbabilities():"
1111  " Probability matrix must have as many columns as there are classes.");
1112  prob.init(0.0);
1113  //store total weights (one accumulator per entry of indices[0])
1114  std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1115  //Go through all trees
1116  int set_id=-1;
1117  for(int k=0; k<options_.tree_count_; ++k)
1118  {
      // NOTE(review): the wrap-around modulus is the length of indices[0],
      // while set_id is then used to index predictionSet.ranges and
      // predictionSet.indices -- confirm this is the intended bound.
1119  set_id=(set_id+1) % predictionSet.indices[0].size();
1120  typedef std::set<SampleRange<T1> > my_set;
1121  typedef typename my_set::iterator set_it;
1122  //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
1123  //Build a stack with all the ranges we have
      // Each stack entry is (node index, range); every range starts at node
      // index 2 (presumably the root's offset in topology_ -- confirm).
1124  std::vector<std::pair<int,set_it> > stack;
1125  stack.clear();
1126  for(set_it i=predictionSet.ranges[set_id].begin();
1127  i!=predictionSet.ranges[set_id].end();++i)
1128  stack.push_back(std::pair<int,set_it>(2,i));
1129  //get weights predicted by single tree
1130  int num_decisions=0;
1131  while(!stack.empty())
1132  {
1133  set_it range=stack.back().second;
1134  int index=stack.back().first;
1135  stack.pop_back();
1136  ++num_decisions;
1137 
1138  if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1139  {
      // Leaf: add its class weights to every sample in [start, end).
1140  ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1141  trees_[k].parameters_,
1142  index).prob_begin();
1143  for(int i=range->start;i!=range->end;++i)
1144  {
1145  //update votecount.
1146  for(int l=0; l<ext_param_.class_count_; ++l)
1147  {
1148  prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
1149  //every weight in totalWeight.
1150  totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
1151  }
1152  }
1153  }
1154 
1155  else
1156  {
1157  if(trees_[k].topology_[index]!=i_ThresholdNode)
1158  {
1159  throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
1160  }
1161  Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1162  if(range->min_boundaries[node.column()]>=node.threshold())
1163  {
1164  //Everything goes to right child
1165  stack.push_back(std::pair<int,set_it>(node.child(1),range));
1166  continue;
1167  }
1168  if(range->max_boundaries[node.column()]<node.threshold())
1169  {
1170  //Everything goes to the left child
1171  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1172  continue;
1173  }
1174  //We have to split at this node
      // new_range collects the samples with feature >= threshold at the
      // tail of the index span; the old range keeps the rest at the front.
      // Both boundary values for this column are reset and rebuilt below.
1175  SampleRange<T1> new_range=*range;
1176  new_range.min_boundaries[node.column()]=FLT_MAX;
1177  range->max_boundaries[node.column()]=-FLT_MAX;
1178  new_range.start=new_range.end=range->end;
1179  int i=range->start;
1180  while(i!=range->end)
1181  {
1182  //Decide for range->indices[i]
1183  if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1184  {
      // Goes right: shrink the old range from the back, grow the new one
      // towards the front, and swap the index into the freed slot. i is
      // not advanced because the swapped-in index still has to be examined.
1185  new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1186  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1187  --range->end;
1188  --new_range.start;
1189  std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1190 
1191  }
1192  else
1193  {
      // Stays left: only the old range's max boundary needs updating.
1194  range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1195  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1196  ++i;
1197  }
1198  }
1199  //The old one: drop it from the set if it became empty, else continue left
1200  if(range->start==range->end)
1201  {
1202  predictionSet.ranges[set_id].erase(range);
1203  }
1204  else
1205  {
1206  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1207  }
1208  //And the new one: register it in the set and continue right if non-empty
1209  if(new_range.start!=new_range.end)
1210  {
1211  std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1212  stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1213  }
1214  }
1215  }
      // Number of stack entries processed for this tree.
1216  predictionSet.cumulativePredTime[k]=num_decisions;
1217  }
1218  for(unsigned int i=0;i<totalWeights.size();++i)
1219  {
1220  double test=0.0;
1221  //Normalise votes in each row by total VoteCount (totalWeight)
1222  for(int l=0; l<ext_param_.class_count_; ++l)
1223  {
1224  test+=prob(i,l);
1225  prob(i, l) /= totalWeights[i];
1226  }
      // NOTE(review): exact floating-point equality; the positivity check
      // runs only after the division above has already been performed.
1227  assert(test==totalWeights[i]);
1228  assert(totalWeights[i]>0.0);
1229  }
1230 }
1231 
// RandomForest::predictProbabilities() -- class probabilities for every row
// of `features`, with an optional early-stopping criterion.
//
// features  one sample per row; must have at least ext_param_.column_count_
//           columns.
// prob      out: rowCount(features) x ext_param_.class_count_; zeroed, then
//           filled with the accumulated tree weights of each row divided by
//           that row's total weight.
// stop_     user stop functor or a default tag (Default_Stop_t(options_) is
//           then used); its after_prediction() is called after every tree
//           and a true result ends the tree loop for the current row.
//
// Rows for which detail::contains_nan() is true get an all-zero result.
//
// NOTE(review): listing line 1234 (return type / class qualifier) is missing
// from this extract.
// NOTE(review): the final division has no guard against totalWeight == 0.
1232 template <class LabelType, class PreprocessorTag>
1233 template <class U, class C1, class T, class C2, class Stop_t>
1235  ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
1236  MultiArrayView<2, T, C2> & prob,
1237  Stop_t & stop_) const
1238 {
1239  //Features are n xp
1240  //prob is n x NumOfLabel probability for each feature in each class
1241 
1242  vigra_precondition(rowCount(features) == rowCount(prob),
1243  "RandomForestn::predictProbabilities():"
1244  " Feature matrix and probability matrix size mismatch.");
1245 
1246  // The feature matrix may have more columns than the forest was trained
1247  // on (the check is >=), but not fewer.
1248  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1249  "RandomForestn::predictProbabilities():"
1250  " Too few columns in feature matrix.");
1251  vigra_precondition( columnCount(prob)
1252  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1253  "RandomForestn::predictProbabilities():"
1254  " Probability matrix must have as many columns as there are classes.");
1255 
      // Pick the caller's stop functor or the default one; the macro is
      // local to these few lines.
1256  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1257  Default_Stop_t default_stop(options_);
1258  typename RF_CHOOSER(Stop_t)::type & stop
1259  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1260  #undef RF_CHOOSER
1261  stop.set_external_parameters(ext_param_, tree_count());
1262  prob.init(NumericTraits<T>::zero());
1263  /* This code was originally there for testing early stopping
1264  * - we wanted the order of the trees to be randomized
1265  if(tree_indices_.size() != 0)
1266  {
1267  std::random_shuffle(tree_indices_.begin(),
1268  tree_indices_.end());
1269  }
1270  */
1271  //Classify for each row.
1272  for(int row=0; row < rowCount(features); ++row)
1273  {
1274  MultiArrayView<2, U, StridedArrayTag> currentRow(rowVector(features, row));
1275 
1276  // when the features contain an NaN, the instance doesn't belong to any class
1277  // => indicate this by returning a zero probability array.
1278  if(detail::contains_nan(currentRow))
1279  {
1280  rowVector(prob, row).init(0.0);
1281  continue;
1282  }
1283 
1284  ArrayVector<double>::const_iterator weights;
1285 
1286  //totalWeight == totalVoteCount!
1287  double totalWeight = 0.0;
1288 
1289  //Let each tree classify...
1290  for(int k=0; k<options_.tree_count_; ++k)
1291  {
1292  //get weights predicted by single tree
1293  weights = trees_[k /*tree_indices_[k]*/].predict(currentRow);
1294 
1295  //update votecount.
      // With predict_weighted_ set, each class weight is scaled by the value
      // stored immediately before the class weights (presumably the leaf's
      // own weight -- confirm against the node layout); otherwise the
      // factor is 1.
1296  int weighted = options_.predict_weighted_;
1297  for(int l=0; l<ext_param_.class_count_; ++l)
1298  {
1299  double cur_w = weights[l] * (weighted * (*(weights-1))
1300  + (1-weighted));
1301  prob(row, l) += static_cast<T>(cur_w);
1302  //every weight in totalWeight.
1303  totalWeight += cur_w;
1304  }
      // Early stopping: the stop functor sees the latest tree's weights, the
      // tree index, the row accumulated so far and the running total.
1305  if(stop.after_prediction(weights,
1306  k,
1307  rowVector(prob, row),
1308  totalWeight))
1309  {
1310  break;
1311  }
1312  }
1313 
1314  //Normalise votes in each row by total VoteCount (totalWeight)
1315  for(int l=0; l< ext_param_.class_count_; ++l)
1316  {
1317  prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1318  }
1319  }
1320 
1321 }
1322 
// RandomForest::predictRaw() -- sum the per-class responses of all trees for
// every row of `features` and divide the sums by options_.tree_count_.
//
// features  one sample per row; must have at least ext_param_.column_count_
//           columns.
// prob      out: rowCount(features) x ext_param_.class_count_; zeroed and
//           then overwritten with the averaged responses.
//
// In contrast to predictProbabilities(), rows are not normalised by their
// accumulated weight, no stop functor is consulted and rows are not checked
// for NaN. Violated preconditions are reported via vigra_precondition().
1323 template <class LabelType, class PreprocessorTag>
1324 template <class U, class C1, class T, class C2>
1325 void RandomForest<LabelType, PreprocessorTag>
1326  ::predictRaw(MultiArrayView<2, U, C1>const & features,
1327  MultiArrayView<2, T, C2> & prob) const
1328 {
1329  //Features are n xp
1330  //prob is n x NumOfLabel response for each feature in each class
1331 
1332  vigra_precondition(rowCount(features) == rowCount(prob),
1333  "RandomForest::predictRaw():"
1334  " Feature matrix and probability matrix size mismatch.");
1335 
1336  // The feature matrix may have more columns than the forest was trained
1337  // on (the check is >=), but not fewer.
1338  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1339  "RandomForest::predictRaw():"
1340  " Too few columns in feature matrix.");
1341  vigra_precondition( columnCount(prob)
1342  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1343  "RandomForest::predictRaw():"
1344  " Probability matrix must have as many columns as there are classes.");
1345 
1347  prob.init(NumericTraits<T>::zero());
1356  //Classify for each row.
      // rowCount() yields a MultiArrayIndex, so the row counter uses that type.
1357  for(MultiArrayIndex row=0; row < rowCount(features); ++row)
1358  {
1359  ArrayVector<double>::const_iterator weights;
1360 
1364  //Let each tree classify...
1365  for(int k=0; k<options_.tree_count_; ++k)
1366  {
1367  //get weights predicted by single tree
1368  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1369 
1370  //update votecount.
      // With predict_weighted_ set, each class weight is scaled by the value
      // stored immediately before the class weights (presumably the leaf's
      // own weight -- confirm against the node layout); otherwise the
      // factor is 1.
1371  int weighted = options_.predict_weighted_;
1372  for(int l=0; l<ext_param_.class_count_; ++l)
1373  {
1374  double cur_w = weights[l] * (weighted * (*(weights-1))
1375  + (1-weighted));
1376  prob(row, l) += static_cast<T>(cur_w);
1379  }
1380  }
1381  }
      // Average over the trees; no per-row normalisation is applied.
1382  prob/= options_.tree_count_;
1383 
1384 }
1385 
1386 } // namespace vigra
1387 
1388 #include "random_forest/rf_algorithm.hxx"
1389 #endif // VIGRA_RANDOM_FOREST_HXX
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:278
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:342
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
Definition: rf_preprocessing.hxx:63
int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:323
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:334
Create random samples from a sequence of indices.
Definition: sampling.hxx:232
const difference_type & shape() const
Definition: multi_array.hxx:1648
void sample()
Definition: sampling.hxx:467
Definition: rf_split.hxx:993
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition: random_forest.hxx:614
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:538
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:197
Standard early stopping criterion.
Definition: rf_common.hxx:885
Definition: random.hxx:669
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:260
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:315
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:308
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:291
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:471
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple labels
Definition: random_forest.hxx:675
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition: random_forest.hxx:147
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:301
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a feature.
Definition: random_forest.hxx:1052
Definition: rf_visitors.hxx:583
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition: random_forest.hxx:590
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:141
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition: sampling.hxx:83
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
int tree_count() const
return number of trees
Definition: random_forest.hxx:349
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:838
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:228
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Options object for the random forest.
Definition: rf_common.hxx:170
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition: random_forest.hxx:639
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1206
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:941
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple labels
Definition: rf_visitors.hxx:234

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)