[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_visitors.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
37 
38 #ifdef HasHDF5
39 # include "vigra/hdf5impex.hxx"
40 #endif // HasHDF5
41 #include <vigra/windows.h>
42 #include <iostream>
43 #include <iomanip>
44 
45 #include <vigra/metaprogramming.hxx>
46 #include <vigra/multi_pointoperators.hxx>
47 #include <vigra/timing.hxx>
48 
49 namespace vigra
50 {
51 namespace rf
52 {
53 /** \brief Visitors to extract information during training of \ref vigra::RandomForest version 2.
54 
55  \ingroup MachineLearning
56 
57  This namespace contains all classes and methods related to extracting information during
58  learning of the random forest. All Visitors share the same interface defined in
59  visitors::VisitorBase. The member methods are invoked at certain points of the main code in
60  the order they were supplied.
61 
62  For the Random Forest the Visitor concept is implemented as a statically linked list
63  (using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The
64  VisitorNode object calls the next Visitor after one of its visit() methods has terminated.
65 
66  To simplify usage create_visitor() factory methods are supplied.
67  Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
68  It is possible to supply more than one visitor. They will then be invoked in serial order.
69 
70  The calculated information is stored in public data members of each visitor class - see the
71  documentation of the individual visitors.
72 
73  When creating a new visitor, the new class should therefore publicly inherit from
74  visitors::VisitorBase (e.g. see visitors::OOB_Error).
75 
76  \code
77 
78  typedef xxx feature_t; // replace xxx with whichever type
79  typedef yyy label_t; // likewise: replace yyy with whichever type
80  MultiArrayView<2, feature_t> f = get_some_features();
81  MultiArrayView<2, label_t> l = get_some_labels();
82  RandomForest<> rf;
83 
84  //calculate OOB Error
85  visitors::OOB_Error oob_v;
86  //calculate Variable Importance
87  visitors::VariableImportanceVisitor varimp_v;
88 
89  double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
90  //the data can be found in the attributes of oob_v and varimp_v now
91 
92  \endcode
93 */
94 namespace visitors
95 {
96 
97 
98 /** Base Class from which all Visitors derive. Can be used as a template to create new
99  * Visitors.
100  */
102 {
103  public:
/** Switch read by is_active(); true means the visitor is switched on. */
bool active_;

/** Report whether this visitor is currently switched on.
 *
 * \return the current value of active_
 */
bool is_active()
{
    return active_;
}
109 
/** Tell whether return_val() of this visitor carries a meaningful result.
 *
 * \return always false in this base implementation
 */
bool has_value()
{
    return false;
}
114 
115  VisitorBase()
116  : active_(true)
117  {}
118 
119  void deactivate()
120  {
121  active_ = false;
122  }
123  void activate()
124  {
125  active_ = true;
126  }
127 
/** Hook called after the Split has decided how to process the Region
 * (stack entry). This default implementation does no work.
 *
 * \param tree       reference to the tree that is currently being learned
 * \param split      reference to the split object
 * \param parent     current stack entry which was used to decide the split
 * \param leftChild  left stack entry that will be pushed
 * \param rightChild right stack entry that will be pushed
 * \param features   features matrix
 * \param labels     label matrix
 * \sa RF_Traits::StackEntry_t
 */
template<class Tree, class Split, class Region, class Feature_t, class Label_t>
void visit_after_split(Tree      & tree,
                       Split     & split,
                       Region    & parent,
                       Region    & leftChild,
                       Region    & rightChild,
                       Feature_t & features,
                       Label_t   & labels)
{
    // The arguments are only handed to ignore_argument(); nothing else happens here.
    ignore_argument(tree, split, parent, leftChild, rightChild, features, labels);
}
152 
/** Hook called after each tree has been learned. This default
 * implementation does no work.
 *
 * \param rf    reference to the random forest object that called this
 *              visitor
 * \param pr    reference to the preprocessor that processed the input
 * \param sm    reference to the sampler object
 * \param st    reference to the first stack entry
 * \param index index of current tree
 */
template<class RF, class PR, class SM, class ST>
void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
{
    // The arguments are only handed to ignore_argument(); nothing else happens here.
    ignore_argument(rf, pr, sm, st, index);
}
167 
/** Hook called after all trees have been learned. This default
 * implementation does no work.
 *
 * \param rf reference to the random forest object that called this
 *           visitor
 * \param pr reference to the preprocessor that processed the input
 */
template<class RF, class PR>
void visit_at_end(RF const & rf, PR const & pr)
{
    // The arguments are only handed to ignore_argument(); nothing else happens here.
    ignore_argument(rf, pr);
}
179 
/** Hook called before learning starts. This default implementation
 * does no work.
 *
 * \param rf reference to the random forest object that called this
 *           visitor
 * \param pr reference to the Processor class used.
 */
template<class RF, class PR>
void visit_at_beginning(RF const & rf, PR const & pr)
{
    // The arguments are only handed to ignore_argument(); nothing else happens here.
    ignore_argument(rf, pr);
}
/** Hook called for external nodes while traversing a tree after it has
 * been learned. This default implementation does no work.
 *
 * \param tr       reference to the tree object that called this visitor
 * \param index    index in the topology_ array we currently are at
 * \param node_t   type of node we have (will be e_.... - )
 * \param features feature matrix
 * \sa NodeTags;
 *
 * You can create the node by using a switch on node_tag and using the
 * corresponding Node objects. Or - if you do not care about the type -
 * use the NodeBase class.
 */
template<class TR, class IntT, class TopT, class Feat>
void visit_external_node(TR & tr, IntT index, TopT node_t, Feat & features)
{
    // The arguments are only handed to ignore_argument(); nothing else happens here.
    ignore_argument(tr, index, node_t, features);
}
209 
/** Hook called when visiting an internal node after it has been learned.
 * This default implementation is empty; the parameters mirror those of
 * visit_external_node().
 *
 * \sa visit_external_node
 */
template<class TR, class IntT, class TopT, class Feat>
void visit_internal_node(TR   & /* tr */,
                         IntT   /* index */,
                         TopT   /* node_t */,
                         Feat & /* features */)
{
}
217 
/** Return a double value. The value of the first visitor encountered
 * that has a return value is returned with the RandomForest::learn()
 * method - or -1.0 if no return value visitor existed. This
 * functionality basically only exists so that the OOB - visitor can
 * return the oob error rate like in the old version of the random
 * forest.
 *
 * \return -1.0 in this base implementation
 */
double return_val()
{
    return -1.0;
}
229 };
230 
231 
232 /** Last Visitor that should be called to stop the recursion.
233  */
235 {
236  public:
/** The terminating visitor always reports a value, so a lookup through
 * the chain ends here at the latest.
 *
 * \return always true
 */
bool has_value()
{
    return true;
}
/** Value reported when the end of the visitor chain is reached.
 *
 * \return always -1.0
 */
double return_val()
{
    return -1.0;
}
245 };
246 namespace detail
247 {
248 /** Container elements of the statically linked Visitor list.
249  *
250  * use the create_visitor() factory functions to create visitors up to size 10;
251  *
252  */
253 template <class Visitor, class Next = StopVisiting>
255 {
256  public:
257 
258  StopVisiting stop_;
259  Next next_;
260  Visitor & visitor_;
261  VisitorNode(Visitor & visitor, Next & next)
262  :
263  next_(next), visitor_(visitor)
264  {}
265 
266  VisitorNode(Visitor & visitor)
267  :
268  next_(stop_), visitor_(visitor)
269  {}
270 
271  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
272  void visit_after_split( Tree & tree,
273  Split & split,
274  Region & parent,
275  Region & leftChild,
276  Region & rightChild,
277  Feature_t & features,
278  Label_t & labels)
279  {
280  if(visitor_.is_active())
281  visitor_.visit_after_split(tree, split,
282  parent, leftChild, rightChild,
283  features, labels);
284  next_.visit_after_split(tree, split, parent, leftChild, rightChild,
285  features, labels);
286  }
287 
288  template<class RF, class PR, class SM, class ST>
289  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
290  {
291  if(visitor_.is_active())
292  visitor_.visit_after_tree(rf, pr, sm, st, index);
293  next_.visit_after_tree(rf, pr, sm, st, index);
294  }
295 
296  template<class RF, class PR>
297  void visit_at_beginning(RF & rf, PR & pr)
298  {
299  if(visitor_.is_active())
300  visitor_.visit_at_beginning(rf, pr);
301  next_.visit_at_beginning(rf, pr);
302  }
303  template<class RF, class PR>
304  void visit_at_end(RF & rf, PR & pr)
305  {
306  if(visitor_.is_active())
307  visitor_.visit_at_end(rf, pr);
308  next_.visit_at_end(rf, pr);
309  }
310 
311  template<class TR, class IntT, class TopT,class Feat>
312  void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
313  {
314  if(visitor_.is_active())
315  visitor_.visit_external_node(tr, index, node_t,features);
316  next_.visit_external_node(tr, index, node_t,features);
317  }
318  template<class TR, class IntT, class TopT,class Feat>
319  void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
320  {
321  if(visitor_.is_active())
322  visitor_.visit_internal_node(tr, index, node_t,features);
323  next_.visit_internal_node(tr, index, node_t,features);
324  }
325 
326  double return_val()
327  {
328  if(visitor_.is_active() && visitor_.has_value())
329  return visitor_.return_val();
330  return next_.return_val();
331  }
332 };
333 
334 } //namespace detail
335 
336 //////////////////////////////////////////////////////////////////////////////
337 // Visitor Factory function up to 10 visitors //
338 //////////////////////////////////////////////////////////////////////////////
339 
340 /** factory method to be used with RandomForest::learn()
341  */
342 template<class A>
345 {
346  typedef detail::VisitorNode<A> _0_t;
347  _0_t _0(a);
348  return _0;
349 }
350 
351 
352 /** factory method to to be used with RandomForest::learn()
353  */
354 template<class A, class B>
355 detail::VisitorNode<A, detail::VisitorNode<B> >
356 create_visitor(A & a, B & b)
357 {
358  typedef detail::VisitorNode<B> _1_t;
359  _1_t _1(b);
360  typedef detail::VisitorNode<A, _1_t> _0_t;
361  _0_t _0(a, _1);
362  return _0;
363 }
364 
365 
366 /** factory method to to be used with RandomForest::learn()
367  */
368 template<class A, class B, class C>
369 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
370 create_visitor(A & a, B & b, C & c)
371 {
372  typedef detail::VisitorNode<C> _2_t;
373  _2_t _2(c);
374  typedef detail::VisitorNode<B, _2_t> _1_t;
375  _1_t _1(b, _2);
376  typedef detail::VisitorNode<A, _1_t> _0_t;
377  _0_t _0(a, _1);
378  return _0;
379 }
380 
381 
382 /** factory method to to be used with RandomForest::learn()
383  */
384 template<class A, class B, class C, class D>
385 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
386  detail::VisitorNode<D> > > >
387 create_visitor(A & a, B & b, C & c, D & d)
388 {
389  typedef detail::VisitorNode<D> _3_t;
390  _3_t _3(d);
391  typedef detail::VisitorNode<C, _3_t> _2_t;
392  _2_t _2(c, _3);
393  typedef detail::VisitorNode<B, _2_t> _1_t;
394  _1_t _1(b, _2);
395  typedef detail::VisitorNode<A, _1_t> _0_t;
396  _0_t _0(a, _1);
397  return _0;
398 }
399 
400 
401 /** factory method to to be used with RandomForest::learn()
402  */
403 template<class A, class B, class C, class D, class E>
404 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
405  detail::VisitorNode<D, detail::VisitorNode<E> > > > >
406 create_visitor(A & a, B & b, C & c,
407  D & d, E & e)
408 {
409  typedef detail::VisitorNode<E> _4_t;
410  _4_t _4(e);
411  typedef detail::VisitorNode<D, _4_t> _3_t;
412  _3_t _3(d, _4);
413  typedef detail::VisitorNode<C, _3_t> _2_t;
414  _2_t _2(c, _3);
415  typedef detail::VisitorNode<B, _2_t> _1_t;
416  _1_t _1(b, _2);
417  typedef detail::VisitorNode<A, _1_t> _0_t;
418  _0_t _0(a, _1);
419  return _0;
420 }
421 
422 
423 /** factory method to to be used with RandomForest::learn()
424  */
425 template<class A, class B, class C, class D, class E,
426  class F>
427 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
428  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
429 create_visitor(A & a, B & b, C & c,
430  D & d, E & e, F & f)
431 {
432  typedef detail::VisitorNode<F> _5_t;
433  _5_t _5(f);
434  typedef detail::VisitorNode<E, _5_t> _4_t;
435  _4_t _4(e, _5);
436  typedef detail::VisitorNode<D, _4_t> _3_t;
437  _3_t _3(d, _4);
438  typedef detail::VisitorNode<C, _3_t> _2_t;
439  _2_t _2(c, _3);
440  typedef detail::VisitorNode<B, _2_t> _1_t;
441  _1_t _1(b, _2);
442  typedef detail::VisitorNode<A, _1_t> _0_t;
443  _0_t _0(a, _1);
444  return _0;
445 }
446 
447 
448 /** factory method to to be used with RandomForest::learn()
449  */
450 template<class A, class B, class C, class D, class E,
451  class F, class G>
452 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
453  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
454  detail::VisitorNode<G> > > > > > >
455 create_visitor(A & a, B & b, C & c,
456  D & d, E & e, F & f, G & g)
457 {
458  typedef detail::VisitorNode<G> _6_t;
459  _6_t _6(g);
460  typedef detail::VisitorNode<F, _6_t> _5_t;
461  _5_t _5(f, _6);
462  typedef detail::VisitorNode<E, _5_t> _4_t;
463  _4_t _4(e, _5);
464  typedef detail::VisitorNode<D, _4_t> _3_t;
465  _3_t _3(d, _4);
466  typedef detail::VisitorNode<C, _3_t> _2_t;
467  _2_t _2(c, _3);
468  typedef detail::VisitorNode<B, _2_t> _1_t;
469  _1_t _1(b, _2);
470  typedef detail::VisitorNode<A, _1_t> _0_t;
471  _0_t _0(a, _1);
472  return _0;
473 }
474 
475 
476 /** factory method to to be used with RandomForest::learn()
477  */
478 template<class A, class B, class C, class D, class E,
479  class F, class G, class H>
480 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
481  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
482  detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
483 create_visitor(A & a, B & b, C & c,
484  D & d, E & e, F & f,
485  G & g, H & h)
486 {
487  typedef detail::VisitorNode<H> _7_t;
488  _7_t _7(h);
489  typedef detail::VisitorNode<G, _7_t> _6_t;
490  _6_t _6(g, _7);
491  typedef detail::VisitorNode<F, _6_t> _5_t;
492  _5_t _5(f, _6);
493  typedef detail::VisitorNode<E, _5_t> _4_t;
494  _4_t _4(e, _5);
495  typedef detail::VisitorNode<D, _4_t> _3_t;
496  _3_t _3(d, _4);
497  typedef detail::VisitorNode<C, _3_t> _2_t;
498  _2_t _2(c, _3);
499  typedef detail::VisitorNode<B, _2_t> _1_t;
500  _1_t _1(b, _2);
501  typedef detail::VisitorNode<A, _1_t> _0_t;
502  _0_t _0(a, _1);
503  return _0;
504 }
505 
506 
507 /** factory method to to be used with RandomForest::learn()
508  */
509 template<class A, class B, class C, class D, class E,
510  class F, class G, class H, class I>
511 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
512  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
513  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
514 create_visitor(A & a, B & b, C & c,
515  D & d, E & e, F & f,
516  G & g, H & h, I & i)
517 {
518  typedef detail::VisitorNode<I> _8_t;
519  _8_t _8(i);
520  typedef detail::VisitorNode<H, _8_t> _7_t;
521  _7_t _7(h, _8);
522  typedef detail::VisitorNode<G, _7_t> _6_t;
523  _6_t _6(g, _7);
524  typedef detail::VisitorNode<F, _6_t> _5_t;
525  _5_t _5(f, _6);
526  typedef detail::VisitorNode<E, _5_t> _4_t;
527  _4_t _4(e, _5);
528  typedef detail::VisitorNode<D, _4_t> _3_t;
529  _3_t _3(d, _4);
530  typedef detail::VisitorNode<C, _3_t> _2_t;
531  _2_t _2(c, _3);
532  typedef detail::VisitorNode<B, _2_t> _1_t;
533  _1_t _1(b, _2);
534  typedef detail::VisitorNode<A, _1_t> _0_t;
535  _0_t _0(a, _1);
536  return _0;
537 }
538 
539 /** factory method to to be used with RandomForest::learn()
540  */
541 template<class A, class B, class C, class D, class E,
542  class F, class G, class H, class I, class J>
543 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
544  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
545  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
546  detail::VisitorNode<J> > > > > > > > > >
547 create_visitor(A & a, B & b, C & c,
548  D & d, E & e, F & f,
549  G & g, H & h, I & i,
550  J & j)
551 {
552  typedef detail::VisitorNode<J> _9_t;
553  _9_t _9(j);
554  typedef detail::VisitorNode<I, _9_t> _8_t;
555  _8_t _8(i, _9);
556  typedef detail::VisitorNode<H, _8_t> _7_t;
557  _7_t _7(h, _8);
558  typedef detail::VisitorNode<G, _7_t> _6_t;
559  _6_t _6(g, _7);
560  typedef detail::VisitorNode<F, _6_t> _5_t;
561  _5_t _5(f, _6);
562  typedef detail::VisitorNode<E, _5_t> _4_t;
563  _4_t _4(e, _5);
564  typedef detail::VisitorNode<D, _4_t> _3_t;
565  _3_t _3(d, _4);
566  typedef detail::VisitorNode<C, _3_t> _2_t;
567  _2_t _2(c, _3);
568  typedef detail::VisitorNode<B, _2_t> _1_t;
569  _1_t _1(b, _2);
570  typedef detail::VisitorNode<A, _1_t> _0_t;
571  _0_t _0(a, _1);
572  return _0;
573 }
574 
575 //////////////////////////////////////////////////////////////////////////////
576 // Visitors of communal interest. //
577 //////////////////////////////////////////////////////////////////////////////
578 
579 
580 /** Visitor to gain information, later needed for online learning.
581  */
582 
584 {
585 public:
586  //Set if we adjust thresholds
587  bool adjust_thresholds;
588  //Current tree id
589  int tree_id;
590  //Last node id for finding parent
591  int last_node_id;
592  //Need to now the label for interior node visiting
593  vigra::Int32 current_label;
594  //marginal distribution for interior nodes
595  //
597  adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
598  {}
599  struct MarginalDistribution
600  {
601  ArrayVector<Int32> leftCounts;
602  Int32 leftTotalCounts;
603  ArrayVector<Int32> rightCounts;
604  Int32 rightTotalCounts;
605  double gap_left;
606  double gap_right;
607  };
609 
610  //All information for one tree
611  struct TreeOnlineInformation
612  {
613  std::vector<MarginalDistribution> mag_distributions;
614  std::vector<IndexList> index_lists;
615  //map for linear index of mag_distributions
616  std::map<int,int> interior_to_index;
617  //map for linear index of index_lists
618  std::map<int,int> exterior_to_index;
619  };
620 
621  //All trees
622  std::vector<TreeOnlineInformation> trees_online_information;
623 
624  /** Initialize, set the number of trees
625  */
626  template<class RF,class PR>
627  void visit_at_beginning(RF & rf,const PR & /* pr */)
628  {
629  tree_id=0;
630  trees_online_information.resize(rf.options_.tree_count_);
631  }
632 
633  /** Reset a tree
634  */
635  void reset_tree(int tree_id)
636  {
637  trees_online_information[tree_id].mag_distributions.clear();
638  trees_online_information[tree_id].index_lists.clear();
639  trees_online_information[tree_id].interior_to_index.clear();
640  trees_online_information[tree_id].exterior_to_index.clear();
641  }
642 
643  /** simply increase the tree count
644  */
645  template<class RF, class PR, class SM, class ST>
646  void visit_after_tree(RF & /* rf */, PR & /* pr */, SM & /* sm */, ST & /* st */, int /* index */)
647  {
648  tree_id++;
649  }
650 
651  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
652  void visit_after_split( Tree & tree,
653  Split & split,
654  Region & parent,
655  Region & leftChild,
656  Region & rightChild,
657  Feature_t & features,
658  Label_t & /* labels */)
659  {
660  int linear_index;
661  int addr=tree.topology_.size();
662  if(split.createNode().typeID() == i_ThresholdNode)
663  {
664  if(adjust_thresholds)
665  {
666  //Store marginal distribution
667  linear_index=trees_online_information[tree_id].mag_distributions.size();
668  trees_online_information[tree_id].interior_to_index[addr]=linear_index;
669  trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
670 
671  trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
672  trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
673 
674  trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
675  trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
676  //Store the gap
677  double gap_left,gap_right;
678  int i;
679  gap_left=features(leftChild[0],split.bestSplitColumn());
680  for(i=1;i<leftChild.size();++i)
681  if(features(leftChild[i],split.bestSplitColumn())>gap_left)
682  gap_left=features(leftChild[i],split.bestSplitColumn());
683  gap_right=features(rightChild[0],split.bestSplitColumn());
684  for(i=1;i<rightChild.size();++i)
685  if(features(rightChild[i],split.bestSplitColumn())<gap_right)
686  gap_right=features(rightChild[i],split.bestSplitColumn());
687  trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
688  trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
689  }
690  }
691  else
692  {
693  //Store index list
694  linear_index=trees_online_information[tree_id].index_lists.size();
695  trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
696 
697  trees_online_information[tree_id].index_lists.push_back(IndexList());
698 
699  trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
700  std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
701  }
702  }
703  void add_to_index_list(int tree,int node,int index)
704  {
705  if(!this->active_)
706  return;
707  TreeOnlineInformation &ti=trees_online_information[tree];
708  ti.index_lists[ti.exterior_to_index[node]].push_back(index);
709  }
710  void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
711  {
712  if(!this->active_)
713  return;
714  trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
715  trees_online_information[src_tree].exterior_to_index.erase(src_index);
716  }
717  /** do something when visiting a internal node during getToLeaf
718  *
719  * remember as last node id, for finding the parent of the last external node
720  * also: adjust class counts and borders
721  */
722  template<class TR, class IntT, class TopT,class Feat>
723  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
724  {
725  last_node_id=index;
726  if(adjust_thresholds)
727  {
728  vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
729  //Check if we are in the gap
730  double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
731  TreeOnlineInformation &ti=trees_online_information[tree_id];
732  MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
733  if(value>m.gap_left && value<m.gap_right)
734  {
735  //Check which site we want to go
736  if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
737  {
738  //We want to go left
739  m.gap_left=value;
740  }
741  else
742  {
743  //We want to go right
744  m.gap_right=value;
745  }
746  Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
747  }
748  //Adjust class counts
749  if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
750  {
751  ++m.rightTotalCounts;
752  ++m.rightCounts[current_label];
753  }
754  else
755  {
756  ++m.leftTotalCounts;
757  ++m.rightCounts[current_label];
758  }
759  }
760  }
761  /** do something when visiting a extern node during getToLeaf
762  *
763  * Store the new index!
764  */
765 };
766 
767 //////////////////////////////////////////////////////////////////////////////
768 // Out of Bag Error estimates //
769 //////////////////////////////////////////////////////////////////////////////
770 
771 
772 /** Visitor that calculates the oob error of each individual randomized
773  * decision tree.
774  *
775  * After training a tree, all those samples that are OOB for this particular tree
776  * are put down the tree and the error estimated.
777  * the per tree oob error is the average of the individual error estimates.
778  * (oobError = average error of one randomized tree)
779  * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
780  * visitor)
781  */
783 {
784 public:
785  /** Average error of one randomized decision tree
786  */
787  double oobError;
788 
789  int totalOobCount;
790  ArrayVector<int> oobCount,oobErrorCount;
791 
793  : oobError(0.0),
794  totalOobCount(0)
795  {}
796 
797 
/** This visitor announces that it carries a return value.
 *
 * NOTE(review): no return_val() override is visible in this class -
 * confirm which value the visitor chain reports for it.
 *
 * \return always true
 */
bool has_value()
{
    return true;
}
802 
803 
804  /** does the basic calculation per tree*/
805  template<class RF, class PR, class SM, class ST>
806  void visit_after_tree(RF & rf, PR & pr, SM & sm, ST &, int index)
807  {
808  //do the first time called.
809  if(int(oobCount.size()) != rf.ext_param_.row_count_)
810  {
811  oobCount.resize(rf.ext_param_.row_count_, 0);
812  oobErrorCount.resize(rf.ext_param_.row_count_, 0);
813  }
814  // go through the samples
815  for(int l = 0; l < rf.ext_param_.row_count_; ++l)
816  {
817  // if the lth sample is oob...
818  if(!sm.is_used()[l])
819  {
820  ++oobCount[l];
821  if( rf.tree(index)
822  .predictLabel(rowVector(pr.features(), l))
823  != pr.response()(l,0))
824  {
825  ++oobErrorCount[l];
826  }
827  }
828 
829  }
830  }
831 
832  /** Does the normalisation
833  */
834  template<class RF, class PR>
835  void visit_at_end(RF & rf, PR &)
836  {
837  // do some normalisation
838  for(int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
839  {
840  if(oobCount[l])
841  {
842  oobError += double(oobErrorCount[l]) / oobCount[l];
843  ++totalOobCount;
844  }
845  }
846  oobError/=totalOobCount;
847  }
848 
849 };
850 
851 /** Visitor that calculates the oob error of the ensemble
852  *
853  * This rate serves as a quick estimate for the crossvalidation
854  * error rate.
855  * Here, each sample is put down the trees for which this sample
856  * is OOB, i.e., if sample #1 is OOB for trees 1, 3 and 5, we calculate
857  * the output using the ensemble consisting only of trees 1 3 and 5.
858  *
859  * Using normal bagged sampling each sample is OOB for approx. 33% of trees.
860  * The error rate obtained as such therefore corresponds to a crossvalidation
861  * rate obtained using a ensemble containing 33% of the trees.
862  */
863 class OOB_Error : public VisitorBase
864 {
865  typedef MultiArrayShape<2>::type Shp;
866  int class_count;
867  bool is_weighted;
868  MultiArray<2,double> tmp_prob;
869  public:
870 
871  MultiArray<2, double> prob_oob;
872  /** Ensemble oob error rate
873  */
874  double oob_breiman;
875 
876  MultiArray<2, double> oobCount;
877  ArrayVector< int> indices;
878  OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
879 #ifdef HasHDF5
880  void save(std::string filen, std::string pathn)
881  {
882  if(*(pathn.end()-1) != '/')
883  pathn += "/";
884  const char* filename = filen.c_str();
885  MultiArray<2, double> temp(Shp(1,1), 0.0);
886  temp[0] = oob_breiman;
887  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
888  }
889 #endif
890  // negative value if sample was ib, number indicates how often.
891  // value >=0 if sample was oob, 0 means fail 1, correct
892 
893  template<class RF, class PR>
894  void visit_at_beginning(RF & rf, PR &)
895  {
896  class_count = rf.class_count();
897  tmp_prob.reshape(Shp(1, class_count), 0);
898  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
899  is_weighted = rf.options().predict_weighted_;
900  indices.resize(rf.ext_param().row_count_);
901  if(int(oobCount.size()) != rf.ext_param_.row_count_)
902  {
903  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
904  }
905  for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
906  {
907  indices[ii] = ii;
908  }
909  }
910 
911  template<class RF, class PR, class SM, class ST>
912  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
913  {
914  // go through the samples
915  int total_oob =0;
916  // FIXME: magic number 10000: invoke special treatment when when msample << sample_count
917  // (i.e. the OOB sample ist very large)
918  // 40000: use at most 40000 OOB samples per class for OOB error estimate
919  if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
920  {
921  ArrayVector<int> oob_indices;
922  ArrayVector<int> cts(class_count, 0);
923  std::random_shuffle(indices.begin(), indices.end());
924  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
925  {
926  if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
927  {
928  oob_indices.push_back(indices[ii]);
929  ++cts[pr.response()(indices[ii], 0)];
930  }
931  }
932  for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
933  {
934  // update number of trees in which current sample is oob
935  ++oobCount[oob_indices[ll]];
936 
937  // update number of oob samples in this tree.
938  ++total_oob;
939  // get the predicted votes ---> tmp_prob;
940  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
941  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
942  rf.tree(index).parameters_,
943  pos);
944  tmp_prob.init(0);
945  for(int ii = 0; ii < class_count; ++ii)
946  {
947  tmp_prob[ii] = node.prob_begin()[ii];
948  }
949  if(is_weighted)
950  {
951  for(int ii = 0; ii < class_count; ++ii)
952  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
953  }
954  rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
955 
956  }
957  }else
958  {
959  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
960  {
961  // if the lth sample is oob...
962  if(!sm.is_used()[ll])
963  {
964  // update number of trees in which current sample is oob
965  ++oobCount[ll];
966 
967  // update number of oob samples in this tree.
968  ++total_oob;
969  // get the predicted votes ---> tmp_prob;
970  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
971  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
972  rf.tree(index).parameters_,
973  pos);
974  tmp_prob.init(0);
975  for(int ii = 0; ii < class_count; ++ii)
976  {
977  tmp_prob[ii] = node.prob_begin()[ii];
978  }
979  if(is_weighted)
980  {
981  for(int ii = 0; ii < class_count; ++ii)
982  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
983  }
984  rowVector(prob_oob, ll) += tmp_prob;
985  }
986  }
987  }
988  // go through the ib samples;
989  }
990 
991  /** Compute the ensemble out-of-bag error after the last tree.
      *
      * Every sample whose oobCount entry is non-zero is counted. It is an
      * error when the argMax of its accumulated prob_oob row differs from
      * its response. The fraction of errors is stored in oob_breiman.
      *
      * \param rf forest; only ext_param_.row_count_ is read here
      * \param pr problem description; only response() is read here
992  */
993  template<class RF, class PR>
994  void visit_at_end(RF & rf, PR & pr)
995  {
996  // ullis original metric and breiman style stuff
997  int totalOobCount =0;
998  int breimanstyle = 0;
999  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1000  {
      // skip samples that were never out-of-bag
1001  if(oobCount[ll])
1002  {
1003  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1004  ++breimanstyle;
1005  ++totalOobCount;
1006  }
1007  }
      // NOTE(review): totalOobCount stays 0 when no sample was ever OOB;
      // the quotient is then not a usable error rate - confirm callers cope.
1008  oob_breiman = double(breimanstyle)/totalOobCount;
1009  }
1010 };
1011 
1012 
1013 /** Visitor that calculates different OOB error statistics
1014  */
1016 {
1017  typedef MultiArrayShape<2>::type Shp;
      // number of classes, taken from rf.class_count() in visit_at_beginning
1018  int class_count;
      // copy of rf.options().predict_weighted_
1019  bool is_weighted;
      // scratch row (1 x class_count) holding the votes of one sample
1020  MultiArray<2,double> tmp_prob;
1021  public:
1022
1023  /** OOB Error rate of each individual tree
1024  */
1025  MultiArray<2, double> oob_per_tree;
1026  /** Mean of oob_per_tree
1027  */
1028  double oob_mean;
1029  /**Standard deviation of oob_per_tree
1030  */
1031  double oob_std;
1032
      /** Votes accumulated over all trees in which a sample was OOB
       * (row_count x class_count).
       */
1033  MultiArray<2, double> prob_oob;
1034  /** Ensemble OOB error
1035  *
1036  * \sa OOB_Error
1037  */
1038  double oob_breiman;
1039
      /** Per sample: number of trees in which the sample was OOB. */
1040  MultiArray<2, double> oobCount;
      /** Per sample: number of trees that misclassified it while it was OOB. */
1041  MultiArray<2, double> oobErrorCount;
1042  /** Per Tree OOB error calculated as in OOB_PerTreeError
1043  * (Ulli's version)
1044  */
1045  double oob_per_tree2;
1046
1047  /**Row vector (1 x tree_count) containing the development of the Ensemble
1048  * error rate with increasing number of trees
1049  */
1050  MultiArray<2, double> breiman_per_tree;
1051  /** 4 dimensional array containing the development of confusion matrices
1052  * with number of trees - can be used to estimate ROC curves etc.
1053  *
1054  * oobroc_per_tree(ii,jj,kk,ll)
1055  * corresponds true label = ii
1056  * predicted label = jj
1057  * confusion matrix after ll trees
1058  *
1059  * explanation of third index:
1060  *
1061  * Two class case:
1062  * kk = 0 - (treeCount-1)
1063  * Threshold is on Probability for class 0 is kk/(treeCount-1);
1064  * More classes:
1065  * kk = 0. Threshold on probability set by argMax of the probability array.
1066  */
1067  MultiArray<4, double> oobroc_per_tree;
1068
1070
1071 #ifdef HasHDF5
1072  /** Save all OOB statistics to an HDF5 file.
      *
      * Writes the datasets oob_per_tree, oobroc_per_tree, breiman_per_tree,
      * per_tree_error (oob_mean), per_tree_error_std (oob_std),
      * breiman_error (oob_breiman) and ulli_error (oob_per_tree2), each
      * below the group path pathn.
      *
      * \param filen name of the HDF5 file
      * \param pathn group path inside the file; a trailing '/' is appended
      *              when missing (an empty path becomes "/")
1073  */
1074  void save(std::string filen, std::string pathn)
1075  {
      // an empty path must not be dereferenced at end()-1
1076  if(pathn.empty() || *(pathn.end()-1) != '/')
1077  pathn += "/";
1078  const char* filename = filen.c_str();
      // 1x1 array used to store the scalar statistics
1079  MultiArray<2, double> temp(Shp(1,1), 0.0);
1080  writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1081  writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1082  writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1083  temp[0] = oob_mean;
1084  writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1085  temp[0] = oob_std;
1086  writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1087  temp[0] = oob_breiman;
1088  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1089  temp[0] = oob_per_tree2;
1090  writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1091  }
1092 #endif
1093  // negative value if sample was ib, number indicates how often.
1094  // value >=0 if sample was oob, 0 means fail 1, correct
1095 
1096  template<class RF, class PR>
1097  void visit_at_beginning(RF & rf, PR &)
1098  {
      // Called before training: sizes every accumulator from the class,
      // row and tree counts of the forest.
1099  class_count = rf.class_count();
      // Two classes: third axis has tree_count entries (one confusion
      // matrix per threshold index). Otherwise the third axis has length 1.
1100  if(class_count == 2)
1101  oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1102  else
1103  oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
1104  tmp_prob.reshape(Shp(1, class_count), 0);
1105  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1106  is_weighted = rf.options().predict_weighted_;
1107  oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1108  breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1109  //do the first time called.
      // NOTE(review): the counters are only re-created when their size
      // differs from row_count_; with an unchanged size they keep their
      // old values while prob_oob above is zeroed - confirm this is intended.
1110  if(int(oobCount.size()) != rf.ext_param_.row_count_)
1111  {
1112  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1113  oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1114  }
1115  }
1116 
1117  template<class RF, class PR, class SM, class ST>
1118  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
1119  {
      // Update all OOB statistics with tree number `index`.
      // Pass 1 sends every sample not used by the sampler `sm` down the
      // tree, adds its votes to prob_oob and counts the tree's own errors.
      // Pass 2 re-evaluates the ensemble prediction of every sample that
      // has been OOB at least once and fills slice `index` of
      // oobroc_per_tree as well as breiman_per_tree[index].
1120  // go through the samples
1121  int total_oob =0;
1122  int wrong_oob =0;
1123  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1124  {
1125  // if the lth sample is oob...
1126  if(!sm.is_used()[ll])
1127  {
1128  // update number of trees in which current sample is oob
1129  ++oobCount[ll];
1130
1131  // update number of oob samples in this tree.
1132  ++total_oob;
1133  // get the predicted votes ---> tmp_prob;
1134  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
1135  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1136  rf.tree(index).parameters_,
1137  pos);
1138  tmp_prob.init(0);
1139  for(int ii = 0; ii < class_count; ++ii)
1140  {
1141  tmp_prob[ii] = node.prob_begin()[ii];
1142  }
1143  if(is_weighted)
1144  {
      // scale by the value stored directly in front of the class
      // entries (presumably the leaf weight - confirm against the
      // node parameter layout)
1145  for(int ii = 0; ii < class_count; ++ii)
1146  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1147  }
1148  rowVector(prob_oob, ll) += tmp_prob;
      // prediction of this single tree
1149  int label = argMax(tmp_prob);
1150
1151  if(label != pr.response()(ll, 0))
1152  {
1153  // update number of wrong oob samples in this tree.
1154  ++wrong_oob;
1155  // update number of trees in which current sample is wrong oob
1156  ++oobErrorCount[ll];
1157  }
1158  }
1159  }
      // ensemble error over all samples that were OOB at least once so far
1160  int breimanstyle = 0;
1161  int totalOobCount = 0;
1162  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1163  {
1164  if(oobCount[ll])
1165  {
1166  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1167  ++breimanstyle;
1168  ++totalOobCount;
      // third axis of length 1: layout chosen in visit_at_beginning when
      // class_count != 2; count (true label, predicted label)
1169  if(oobroc_per_tree.shape(2) == 1)
1170  {
1171  oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
1172  }
1173  }
1174  }
      // turn the counts of this tree's slice into fractions
1175  if(oobroc_per_tree.shape(2) == 1)
1176  oobroc_per_tree.bindOuter(index)/=totalOobCount;
1177  if(oobroc_per_tree.shape(2) > 1)
1178  {
1179  MultiArrayView<3, double> current_roc
1180  = oobroc_per_tree.bindOuter(index);
      // one confusion matrix per threshold index gg
1181  for(int gg = 0; gg < current_roc.shape(2); ++gg)
1182  {
1183  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1184  {
1185  if(oobCount[ll])
1186  {
      // NOTE(review): prob_oob holds votes summed over trees and is
      // not divided by oobCount in the visible code, while the
      // threshold gg/shape(2) lies in [0,1) - confirm this comparison
      // is intended.
1187  int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1188  1 : 0;
1189  current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1190  }
1191  }
1192  current_roc.bindOuter(gg)/= totalOobCount;
1193  }
1194  }
      // NOTE(review): both denominators are 0 when the tree (or the whole
      // run so far) has no OOB sample - confirm this cannot happen.
1195  breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
1196  oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1197  // go through the ib samples;
1198  }
1199 
1200  /** Compute the final OOB statistics after the last tree.
      *
      * Over all samples that were OOB at least once this sets
      *  - oob_breiman: fraction whose accumulated-vote argMax differs
      *    from the response,
      *  - oob_per_tree2: mean of oobErrorCount/oobCount.
      * oob_mean and oob_std are filled by rowStatistics() from
      * oob_per_tree through 1x1 views on the two members.
1201  */
1202  template<class RF, class PR>
1203  void visit_at_end(RF & rf, PR & pr)
1204  {
1205  // ullis original metric and breiman style stuff
1206  oob_per_tree2 = 0;
1207  int totalOobCount =0;
1208  int breimanstyle = 0;
1209  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1210  {
      // skip samples that were never out-of-bag
1211  if(oobCount[ll])
1212  {
1213  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1214  ++breimanstyle;
1215  oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1216  ++totalOobCount;
1217  }
1218  }
      // NOTE(review): totalOobCount is 0 when no sample was ever OOB -
      // confirm callers cope with the resulting values.
1219  oob_per_tree2 /= totalOobCount;
1220  oob_breiman = double(breimanstyle)/totalOobCount;
1221  // mean error of each tree
      // the views alias oob_mean / oob_std, so rowStatistics writes into them
1222  MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
1223  MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1224  rowStatistics(oob_per_tree, mean, stdDev);
1225  }
1226 };
1227 
1228 /** calculate variable importance while learning.
1229  */
1231 {
1232  public:
1233 
1234  /** This Array has the same entries as the R - random forest variable
1235  * importance.
1236  * Matrix is featureCount by (classCount +2)
1237  * variable_importance_(ii,jj) is the variable importance measure of
1238  * the ii-th variable according to:
1239  * jj = 0 - (classCount-1)
1240  * classwise permutation importance
1241  * jj = columnCount(variable_importance_) -2
1242  * permutation importance
1243  * jj = columnCount(variable_importance_) -1
1244  * gini decrease importance.
1245  *
1246  * permutation importance:
1247  * The difference between the fraction of OOB samples classified correctly
1248  * before and after permuting (randomizing) the ii-th column is calculated.
1249  * The ii-th column is permuted rep_cnt times.
1250  *
1251  * class wise permutation importance:
1252  * same as permutation importance. We only look at those OOB samples whose
1253  * response corresponds to class jj.
1254  *
1255  * gini decrease importance:
1256  * row ii corresponds to the sum of all gini decreases induced by variable ii
1257  * in each node of the random forest.
1258  */
1259  MultiArray<2, double> variable_importance_;
      /** Number of permutations per column and tree (constructor argument). */
1260  int repetition_count_;
      /** Flag mentioned in the documentation of visit_after_tree; it is
       * not read by any method visible in this class.
       */
1261  bool in_place_;
1262 
1263 #ifdef HasHDF5
1264  void save(std::string filename, std::string prefix)
1265  {
      // Writes one dataset named "variable_importance_" + prefix into the
      // HDF5 file `filename`.
1266  prefix = "variable_importance_" + prefix;
1267  writeHDF5(filename.c_str(),
1268  prefix.c_str(),
      // NOTE(review): listing line 1269 (the array argument and the closing
      // parenthesis of this call) is absent from this listing - presumably
      // variable_importance_; confirm against the original header.
1270  }
1271 #endif
1272 
1273  /** Constructor
1274  * \param rep_cnt (default: 10) how often should
1275  * the permutation take place. Set to 1 to make calculation faster (but
1276  * possibly more unstable)
      *
      * in_place_ starts as false so that the flag never holds an
      * indeterminate value.
1277  */
1278  VariableImportanceVisitor(int rep_cnt = 10)
1279  : repetition_count_(rep_cnt),
1280  in_place_(false)
1281  {}
1282 
1283  /** Accumulate the impurity (gini) decrease of a split.
      *
      * For a threshold split, region_gini_ - minGini() of `split` is added
      * to column class_count+1 of the row belonging to the split column.
      * The remaining arguments are unused.
1284  *
1285  */
1286  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1287  void visit_after_split( Tree & tree,
1288  Split & split,
1289  Region & /* parent */,
1290  Region & /* leftChild */,
1291  Region & /* rightChild */,
1292  Feature_t & /* features */,
1293  Label_t & /* labels */)
1294  {
1295  //resize to right size when called the first time
1296
1297  Int32 const class_count = tree.ext_param_.class_count_;
1298  Int32 const column_count = tree.ext_param_.column_count_;
1299  if(variable_importance_.size() == 0)
1300  {
1301
      // NOTE(review): listing line 1302 (the object on which reshape is
      // called) is absent from this listing - presumably
      // variable_importance_; confirm against the original header.
1303  .reshape(MultiArrayShape<2>::type(column_count,
1304  class_count+2));
1305  }
1306
      // only threshold splits contribute
1307  if(split.createNode().typeID() == i_ThresholdNode)
1308  {
1309  Node<i_ThresholdNode> node(split.createNode());
1310  variable_importance_(node.column(),class_count+1)
1311  += split.region_gini_ - split.minGini();
1312  }
1313  }
1314 
1315  /** Compute the permutation based variable importance of one tree.
1316  *
      * Steps, for tree number `index`:
      *  - collect the rows not used by the sampler `sm` (OOB rows),
      *  - count the correctly predicted OOB rows per class and in total,
      *  - for every column: permute its OOB entries repetition_count_
      *    times, count the correct predictions after each permutation,
      *    add (original - mean permuted) / number of OOB rows to row ii of
      *    the importance matrix and restore the column.
      * Only one column of the OOB rows is backed up at a time, as opposed
      * to an oob_sample_count x feature_count copy.
      * NOTE(review): `features` below is a local object of type
      * PR::FeatureWithMemory_t initialised from pr.features() -
      * presumably a full copy; confirm.
1317  *
1318  *
1319  * \sa FieldProxy
1320  */
1321  template<class RF, class PR, class SM, class ST>
1322  void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & /* st */, int index)
1323  {
1324  typedef MultiArrayShape<2>::type Shp_t;
1325  Int32 column_count = rf.ext_param_.column_count_;
1326  Int32 class_count = rf.ext_param_.class_count_;
1327
1328  /* This solution saves memory uptake but not multithreading
1329  * compatible
1330  */
1331  // remove the const cast on the features (yep , I know what I am
1332  // doing here.) data is not destroyed.
1333  //typename PR::Feature_t & features
1334  // = const_cast<typename PR::Feature_t &>(pr.features());
1335
1336  typedef typename PR::FeatureWithMemory_t FeatureArray;
1337  typedef typename FeatureArray::value_type FeatureValue;
1338
1339  FeatureArray features = pr.features();
1340
1341  //find the oob indices of current tree.
1342  ArrayVector<Int32> oob_indices;
      // NOTE(review): listing line 1343 (the type of `iter`, presumably
      // the iterator type of oob_indices) is absent from this listing.
1344  iter;
1345  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1346  if(!sm.is_used()[ii])
1347  oob_indices.push_back(ii);
1348
1349  //create space to back up a column
1350  ArrayVector<FeatureValue> backup_column;
1351
1352  // Random foo
      // fixed seed 1 under CLASSIFIER_TEST, RandomSeed otherwise
1353 #ifdef CLASSIFIER_TEST
1354  RandomMT19937 random(1);
1355 #else
1356  RandomMT19937 random(RandomSeed);
1357 #endif
      // NOTE(review): listing line 1358 (the type of `randint`) is absent
      // from this listing.
1359  randint(random);
1360
1361
1362  //make some space for the results
      // NOTE(review): listing lines 1363 and 1365 (the types of oob_right
      // and perm_oob_right) are absent from this listing.
1364  oob_right(Shp_t(1, class_count + 1));
1366  perm_oob_right (Shp_t(1, class_count + 1));
1367
1368
1369  // get the oob success rate with the original samples
1370  for(iter = oob_indices.begin();
1371  iter != oob_indices.end();
1372  ++iter)
1373  {
1374  if(rf.tree(index)
1375  .predictLabel(rowVector(features, *iter))
1376  == pr.response()(*iter, 0))
1377  {
1378  //per class
1379  ++oob_right[pr.response()(*iter,0)];
1380  //total
1381  ++oob_right[class_count];
1382  }
1383  }
1384  //get the oob rate after permuting the ii'th dimension.
1385  for(int ii = 0; ii < column_count; ++ii)
1386  {
1387  perm_oob_right.init(0.0);
1388  //make backup of original column
1389  backup_column.clear();
1390  for(iter = oob_indices.begin();
1391  iter != oob_indices.end();
1392  ++iter)
1393  {
1394  backup_column.push_back(features(*iter,ii));
1395  }
1396
1397  //get the oob rate after permuting the ii'th dimension.
1398  for(int rr = 0; rr < repetition_count_; ++rr)
1399  {
1400  //permute dimension.
      // walk from the last OOB row to the second one, swapping each
      // entry with the one at index randint(jj+1); the column is not
      // restored between repetitions
1401  int n = oob_indices.size();
1402  for(int jj = n-1; jj >= 1; --jj)
1403  std::swap(features(oob_indices[jj], ii),
1404  features(oob_indices[randint(jj+1)], ii));
1405
1406  //get the oob success rate after permuting
1407  for(iter = oob_indices.begin();
1408  iter != oob_indices.end();
1409  ++iter)
1410  {
1411  if(rf.tree(index)
1412  .predictLabel(rowVector(features, *iter))
1413  == pr.response()(*iter, 0))
1414  {
1415  //per class
1416  ++perm_oob_right[pr.response()(*iter, 0)];
1417  //total
1418  ++perm_oob_right[class_count];
1419  }
1420  }
1421  }
1422
1423
1424  //normalise and add to the variable_importance array.
      // result: (oob_right - mean permuted count) / number of OOB rows
      // NOTE(review): oob_indices.size() is 0 when the tree has no OOB
      // row - confirm this cannot happen.
1425  perm_oob_right /= repetition_count_;
1426  perm_oob_right -=oob_right;
1427  perm_oob_right *= -1;
1428  perm_oob_right /= oob_indices.size();
      // NOTE(review): listing line 1429 (the array whose row ii receives
      // the result, presumably variable_importance_) is absent from this
      // listing.
1430  .subarray(Shp_t(ii,0),
1431  Shp_t(ii+1,class_count+1)) += perm_oob_right;
1432  //copy back permuted dimension
1433  for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1434  features(oob_indices[jj], ii) = backup_column[jj];
1435  }
1436  }
1437 
1438  /** Calculate the permutation based importance after every tree has
1439  * been learned. This forwards all arguments to after_tree_ip_impl().
1440  * NOTE(review): the in_place_ flag is not consulted here; no
1441  * alternative implementation is selected in the visible code.
1442  */
1443  template<class RF, class PR, class SM, class ST>
1444  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1445  {
1446  after_tree_ip_impl(rf, pr, sm, st, index);
1447  }
1448 
1449  /** Normalise variable importance after the number of trees is known:
      * every entry of variable_importance_ is divided by rf.trees_.size().
1450  */
1451  template<class RF, class PR>
1452  void visit_at_end(RF & rf, PR & /* pr */)
1453  {
1454  variable_importance_ /= rf.trees_.size();
1455  }
1456 };
1457 
1458 /** Verbose output
1459  */
1461  public:
1463 
1464  template<class RF, class PR, class SM, class ST>
1465  void visit_after_tree(RF& rf, PR &, SM &, ST &, int index){
      // Print the training progress to std::cout after tree `index`.
      // The leading "\r" returns to the start of the line so the text is
      // overwritten; only the last tree ends the line.
1466  if(index != rf.options().tree_count_-1) {
1467  std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1468  << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1469  }
1470  else {
1471  std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1472  }
1473  }
1474 
1475  template<class RF, class PR>
1476  void visit_at_end(RF const & rf, PR const &) {
      // Print the number of trees and the string produced by the TOCS
      // macro (timing.hxx); the timer is started with TIC in
      // visit_at_beginning.
1477  std::string a = TOCS;
1478  std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
1479  }
1480 
1481  template<class RF, class PR>
1482  void visit_at_beginning(RF const & rf, PR const &) {
      // Start the timer (TIC macro from timing.hxx) and announce how many
      // trees will be grown.
1483  TIC;
1484  std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
1485  }
1486 
1487  private:
1488  USETICTOC;
1489 };
1490 
1491 
1492 /** Computes Correlation/Similarity Matrix of features while learning
1493  * random forest.
1494  */
1496 {
1497  public:
1498  /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1499  * created on variable ii(when variable ii was chosen)
1500  */
1501  MultiArray<2, double> gini_missc;
      /** Scratch labels: for every sample of the current region, whether
       * its split feature lies below the split threshold.
       */
1502  MultiArray<2, int> tmp_labels;
1503  /** additional noise features.
1504  */
1505  MultiArray<2, double> noise;
      /** additional binary (0/1) noise columns. */
1506  MultiArray<2, double> noise_l;
1507  /** how well can a noise column describe a partition created on variable ii.
1508  */
1509  MultiArray<2, double> corr_noise;
      /** same as corr_noise, for the binary noise columns noise_l. */
1510  MultiArray<2, double> corr_l;
1511
1512  /** Similarity Matrix
1513  *
1514  * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
1515  * gini_missc
1516  * - row normalized by the number of times the column was chosen
1517  * - mean of corr_noise subtracted
1518  * - and symmetrised.
1519  *
1520  */
1521  MultiArray<2, double> similarity;
1522  /** Distance Matrix 1-similarity
1523  */
1524  MultiArray<2, double> distance;
      /** Scratch counts of the two tmp_labels values in the current region. */
1525  ArrayVector<int> tmp_cc;
1526
1527  /** How often was variable ii chosen
1528  */
1529  ArrayVector<int> numChoices;
1532  void save(std::string, std::string)
1533  {
      // Currently a no-op: both arguments are ignored and the whole body
      // below is commented out.
1534  /*
1535  std::string tmp;
1536 #define VAR_WRITE(NAME) \
1537  tmp = #NAME;\
1538  tmp += "_";\
1539  tmp += prefix;\
1540  vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
1541  VAR_WRITE(gini_missc);
1542  VAR_WRITE(corr_noise);
1543  VAR_WRITE(distance);
1544  VAR_WRITE(similarity);
1545  vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
1546 #undef VAR_WRITE
1547 */
1548  }
1549 
1550  template<class RF, class PR>
1551  void visit_at_beginning(RF const & rf, PR & pr)
1552  {
      // Size all accumulators for n = column_count_ features (plus one
      // extra row/column) and draw the 10 random noise columns.
1553  typedef MultiArrayShape<2>::type Shp;
1554  int n = rf.ext_param_.column_count_;
1555  gini_missc.reshape(Shp(n +1,n+ 1));
1556  corr_noise.reshape(Shp(n + 1, 10));
1557  corr_l.reshape(Shp(n +1, 10));
1558
      // one noise value per sample and noise column
1559  noise.reshape(Shp(pr.features().shape(0), 10));
1560  noise_l.reshape(Shp(pr.features().shape(0), 10));
1561  RandomMT19937 random(RandomSeed);
1562  for(int ii = 0; ii < noise.size(); ++ii)
1563  {
      // noise: value of uniform53(); noise_l: 0/1 from an independent draw
1564  noise[ii] = random.uniform53();
1565  noise_l[ii] = random.uniform53() > 0.5;
1566  }
1567  bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1568  tmp_labels.reshape(pr.response().shape());
      // two entries: one count per side of a split
1569  tmp_cc.resize(2);
1570  numChoices.resize(n+1);
1571  // look at all axes
1572  }
1573  template<class RF, class PR>
1574  void visit_at_end(RF const &, PR const &)
1575  {
      // Derive `similarity` and `distance` from the accumulated sums:
      // normalise by numChoices, subtract the noise mean, symmetrise the
      // feature part, and set distance = max(similarity) - similarity.
      // NOTE(review): listing lines 1577 and 1578 are absent from this
      // listing - presumably they size `similarity` and initialise it
      // from gini_missc; confirm against the original header.
1576  typedef MultiArrayShape<2>::type Shp;
1579  MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
1580  rowStatistics(corr_noise, mean_noise);
1581  mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
1582  int rC = similarity.shape(0);
      // NOTE(review): numChoices[jj] is 0 for a variable that was never
      // chosen - confirm the divisions below are acceptable in that case.
1583  for(int jj = 0; jj < rC-1; ++jj)
1584  {
1585  rowVector(similarity, jj) /= numChoices[jj];
1586  rowVector(similarity, jj) -= mean_noise(jj, 0);
1587  }
      // the last row is normalised element-wise by the column's count
1588  for(int jj = 0; jj < rC; ++jj)
1589  {
1590  similarity(rC -1, jj) /= numChoices[jj];
1591  }
1592  rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
      // NOTE(review): listing line 1593 is absent from this listing.
1594  FindMinMax<double> minmax;
1595  inspectMultiArray(srcMultiArrayRange(similarity), minmax);
1596
1597  for(int jj = 0; jj < rC; ++jj)
1598  similarity(jj, jj) = minmax.max;
1599
      // symmetrise the feature-by-feature part: (S + S^T) / 2
1600  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1601  += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1602  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
      // mirror the last row into the last column
1603  columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
      // zero the diagonal so it does not influence the second maximum
1604  for(int jj = 0; jj < rC; ++jj)
1605  similarity(jj, jj) = 0;
1606
1607  FindMinMax<double> minmax2;
1608  inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
1609  for(int jj = 0; jj < rC; ++jj)
1610  similarity(jj, jj) = minmax2.max;
      // distance = max - similarity
1611  distance.reshape(gini_missc.shape(), minmax2.max);
1612  distance -= similarity;
1613  }
1614 
1615  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1616  void visit_after_split( Tree &,
1617  Split & split,
1618  Region & parent,
1619  Region &,
1620  Region &,
1621  Feature_t & features,
1622  Label_t & labels)
1623  {
      // For a threshold split: relabel the samples of `parent` by the side
      // of the split they fall on, then measure how well every feature
      // column and every noise column reproduces that partition
      // (region_gini - bgfunc.min_gini_) and add the result to the row of
      // the split column n. Finally `parent` is partitioned by the split.
      // NOTE(review): several lines of this method (1676, 1682, 1692,
      // 1701, 1713, 1718, 1719) are absent from this listing; the
      // statements they start are marked below.
1624  if(split.createNode().typeID() == i_ThresholdNode)
1625  {
1626  double wgini;
1627  tmp_cc.init(0);
1628  for(int ii = 0; ii < parent.size(); ++ii)
1629  {
      // 1 when the sample lies below the threshold, 0 otherwise
1630  tmp_labels[parent[ii]]
1631  = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1632  ++tmp_cc[tmp_labels[parent[ii]]];
1633  }
1634  double region_gini = bgfunc.loss_of_region(tmp_labels,
1635  parent.begin(),
1636  parent.end(),
1637  tmp_cc);
1638
1639  int n = split.bestSplitColumn();
1640  ++numChoices[n];
      // the last entry counts all splits
1641  ++(*(numChoices.end()-1));
1642  //this functor does all the work
1643  for(int k = 0; k < features.shape(1); ++k)
1644  {
1645  bgfunc(columnVector(features, k),
1646  tmp_labels,
1647  parent.begin(), parent.end(),
1648  tmp_cc);
1649  wgini = (region_gini - bgfunc.min_gini_);
1650  gini_missc(n, k)
1651  += wgini;
1652  }
1653  for(int k = 0; k < 10; ++k)
1654  {
1655  bgfunc(columnVector(noise, k),
1656  tmp_labels,
1657  parent.begin(), parent.end(),
1658  tmp_cc);
1659  wgini = (region_gini - bgfunc.min_gini_);
1660  corr_noise(n, k)
1661  += wgini;
1662  }
1663
1664  for(int k = 0; k < 10; ++k)
1665  {
1666  bgfunc(columnVector(noise_l, k),
1667  tmp_labels,
1668  parent.begin(), parent.end(),
1669  tmp_cc);
1670  wgini = (region_gini - bgfunc.min_gini_);
1671  corr_l(n, k)
1672  += wgini;
1673  }
1674  bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1675  wgini = (region_gini - bgfunc.min_gini_);
      // NOTE(review): listing line 1676 (the target of this +=) is missing.
1677  += wgini;
1678
      // from here on the gini of the real labels is used
1679  region_gini = split.region_gini_;
1680 #if 1
1681  Node<i_ThresholdNode> node(split.createNode());
      // NOTE(review): listing line 1682 (array and first index) is missing.
1683  node.column())
1684  +=split.region_gini_ - split.minGini();
1685 #endif
1686  for(int k = 0; k < 10; ++k)
1687  {
1688  split.bgfunc(columnVector(noise, k),
1689  labels,
1690  parent.begin(), parent.end(),
1691  parent.classCounts());
      // NOTE(review): listing line 1692 (array and first index) is missing.
      // wgini is not recomputed inside this loop in the visible lines, so
      // the value from line 1675 is added - confirm this is intended.
1693  k)
1694  += wgini;
1695  }
1696 #if 0
1697  for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1698  {
1699  wgini = region_gini - split.min_gini_[k];
1700
1702  split.splitColumns[k])
1703  += wgini;
1704  }
1705
1706  for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1707  {
1708  split.bgfunc(columnVector(features, split.splitColumns[k]),
1709  labels,
1710  parent.begin(), parent.end(),
1711  parent.classCounts());
1712  wgini = region_gini - split.bgfunc.min_gini_;
1714  split.splitColumns[k]) += wgini;
1715  }
1716 #endif
1717  // remember to partition the data according to the best.
      // NOTE(review): listing lines 1718-1719 (the target of this +=) are
      // missing.
1720  += region_gini;
      // reorder the samples of `parent` with the split predicate
1721  SortSamplesByDimensions<Feature_t>
1722  sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1723  std::partition(parent.begin(), parent.end(), sorter);
1724  }
1725  }
1726 };
1727 
1728 
1729 } // namespace visitors
1730 } // namespace rf
1731 } // namespace vigra
1732 
1733 #endif // RF_VISITORS_HXX
#define TIC
Definition: timing.hxx:322
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:994
MultiArray< 2, double > oob_per_tree
Definition: rf_visitors.hxx:1025
void visit_at_beginning(RF &rf, const PR &)
Definition: rf_visitors.hxx:627
MultiArray< 2, double > noise
Definition: rf_visitors.hxx:1505
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
MultiArray< 2, double > variable_importance_
Definition: rf_visitors.hxx:1259
const difference_type & shape() const
Definition: multi_array.hxx:1648
MultiArray< 2, double > breiman_per_tree
Definition: rf_visitors.hxx:1050
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_visitors.hxx:1322
MultiArray< 2, double > corr_noise
Definition: rf_visitors.hxx:1509
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition: rf_visitors.hxx:646
const_iterator begin() const
Definition: array_vector.hxx:223
MultiArray< 2, double > similarity
Definition: rf_visitors.hxx:1521
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Definition: rf_visitors.hxx:863
MultiArray< 4, double > oobroc_per_tree
Definition: rf_visitors.hxx:1067
Definition: rf_visitors.hxx:1495
ArrayVector< int > numChoices
Definition: rf_visitors.hxx:1529
Definition: rf_visitors.hxx:1230
MultiArrayView< N, T, StridedArrayTag > transpose() const
Definition: multi_array.hxx:1567
double oob_breiman
Definition: rf_visitors.hxx:1038
Definition: random.hxx:669
double oob_mean
Definition: rf_visitors.hxx:1028
MultiArray< 2, double > gini_missc
Definition: rf_visitors.hxx:1501
double return_val()
Definition: rf_visitors.hxx:225
difference_type_1 size() const
Definition: multi_array.hxx:1641
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1524
Definition: multi_fwd.hxx:63
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:635
double oobError
Definition: rf_visitors.hxx:787
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition: rf_visitors.hxx:215
void init(U const &initial)
Definition: array_vector.hxx:146
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_visitors.hxx:806
double oob_per_tree2
Definition: rf_visitors.hxx:1045
Definition: rf_split.hxx:831
MultiArray & init(const U &init)
Definition: multi_array.hxx:2851
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
void visit_at_end(RF &rf, PR &)
Definition: rf_visitors.hxx:1452
Definition: rf_visitors.hxx:1015
Definition: rf_visitors.hxx:583
double oob_breiman
Definition: rf_visitors.hxx:874
#define TOCS
Definition: timing.hxx:325
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void writeHDF5(...)
Store array data in an HDF5 file.
Definition: rf_visitors.hxx:1460
double oob_std
Definition: rf_visitors.hxx:1031
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:723
void visit_at_end(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:175
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:1203
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: rf_visitors.hxx:101
void visit_at_end(RF &rf, PR &)
Definition: rf_visitors.hxx:835
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
void visit_at_beginning(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:187
Definition: random.hxx:336
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition: rf_visitors.hxx:142
const_iterator end() const
Definition: array_vector.hxx:237
const_pointer data() const
Definition: array_vector.hxx:209
size_type size() const
Definition: array_vector.hxx:358
MultiArrayView subarray(difference_type p, difference_type q) const
Definition: multi_array.hxx:1528
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition: rf_visitors.hxx:1287
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:205
Definition: rf_visitors.hxx:782
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:163
void rowStatistics(...)
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:1444
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const
Definition: multi_array.hxx:2184
Definition: rf_visitors.hxx:234
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)