[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_preprocessing.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 #ifndef VIGRA_RF_PREPROCESSING_HXX
37 #define VIGRA_RF_PREPROCESSING_HXX
38 
39 #include <limits>
40 #include <vigra/mathutil.hxx>
41 #include "rf_common.hxx"
42 
43 namespace vigra
44 {
45 
/** Preprocessor used by the Random Forest learn function (currently its only
 *  user).
 *
 *  Different split functors need the training data prepared in different
 *  ways: regression labels must be passed through untouched, whereas
 *  classification labels must be converted into an integral format.
 *
 *  The primary template is only declared, never defined; usable versions
 *  exist solely as specializations with a fixed Tag. The Tag is taken from
 *  the split functor's Preprocessor_t typedef and is currently either
 *  ClassificationTag or RegressionTag. The RegressionTag specialization
 *  shows the basic interface that any new preprocessor (e.g. one for soft
 *  labels) would have to provide.
 */
template <class Tag,
          class LabelType,
          class T1, class C1,
          class T2, class C2>
class Processor;
64 
65 namespace detail
66 {
67 
68  /* Common helper function used in all Processors.
69  * This function analyses the options struct and calculates the real
70  * values needed for the current problem (data)
71  */
72  template<class T>
73  void fill_external_parameters(RandomForestOptions const & options,
74  ProblemSpec<T> & ext_param)
75  {
76  // set correct value for mtry
77  switch(options.mtry_switch_)
78  {
79  case RF_SQRT:
80  ext_param.actual_mtry_ =
81  int(std::floor(
82  std::sqrt(double(ext_param.column_count_))
83  + 0.5));
84  break;
85  case RF_LOG:
86  // this is in Breimans original paper
87  ext_param.actual_mtry_ =
88  int(1+(std::log(double(ext_param.column_count_))
89  /std::log(2.0)));
90  break;
91  case RF_FUNCTION:
92  ext_param.actual_mtry_ =
93  options.mtry_func_(ext_param.column_count_);
94  break;
95  case RF_ALL:
96  ext_param.actual_mtry_ = ext_param.column_count_;
97  break;
98  default:
99  ext_param.actual_mtry_ =
100  options.mtry_;
101  }
102  // set correct value for msample
103  switch(options.training_set_calc_switch_)
104  {
105  case RF_CONST:
106  ext_param.actual_msample_ =
107  options.training_set_size_;
108  break;
109  case RF_PROPORTIONAL:
110  ext_param.actual_msample_ =
111  static_cast<int>(std::ceil(options.training_set_proportion_ *
112  ext_param.row_count_));
113  break;
114  case RF_FUNCTION:
115  ext_param.actual_msample_ =
116  options.training_set_func_(ext_param.row_count_);
117  break;
118  default:
119  vigra_precondition(1!= 1, "unexpected error");
120 
121  }
122 
123  }
124 
125  /* Returns true if MultiArray contains NaNs
126  */
127  template<unsigned int N, class T, class C>
128  bool contains_nan(MultiArrayView<N, T, C> const & in)
129  {
130  typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
131  Iter i = in.begin(), end = in.end();
132  for(; i != end; ++i)
133  if(isnan(NumericTraits<T>::toRealPromote(*i)))
134  return true;
135  return false;
136  }
137 
138  /* Returns true if MultiArray contains Infs
139  */
140  template<unsigned int N, class T, class C>
141  bool contains_inf(MultiArrayView<N, T, C> const & in)
142  {
143  if(!std::numeric_limits<T>::has_infinity)
144  return false;
145  typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
146  Iter i = in.begin(), end = in.end();
147  for(; i != end; ++i)
148  if(abs(*i) == std::numeric_limits<T>::infinity())
149  return true;
150  return false;
151  }
152 } // namespace detail
153 
154 
155 
156 /** Preprocessor used during Classification
157  *
158  * This class converts the labels into integral labels which are used by the
159  * standard split functor to address memory in the node objects.
 *
 * The integral label of a sample is the position of its original label in
 * ext_param.classes. The feature matrix is not copied: features_ only keeps
 * a reference to the caller's matrix, so that matrix must outlive this
 * object.
160  */
161 template<class LabelType, class T1, class C1, class T2, class C2>
162 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
163 {
164  public:
165  typedef Int32 LabelInt;
    // Reference to the caller's feature matrix (never modified here).
169  MultiArrayView<2, T1, C1>const & features_;
    // Integral labels, reshaped to response.shape(); the constructor only
    // writes column 0.
170  MultiArray<2, LabelInt> intLabels_;
172 
    // Validates the training data and fills the problem-specific fields of
    // ext_param.
    //
    // - Fails a vigra_precondition if features or response contain NaN or
    //   inf.
    // - Sets ext_param.column_count_ / row_count_ from features.shape(1) /
    //   features.shape(0), problem_type_ = CLASSIFICATION and used_ = true.
    // - If ext_param.class_count_ == 0, builds the class list from the
    //   distinct values of response(k,0), in std::set<T2> order (ascending by
    //   operator<). Otherwise the classes already stored in ext_param are
    //   used.
    // - Maps every label to its index in ext_param.classes. A label not found
    //   there throws std::runtime_error.
    // - If no class weights were supplied, gives every class the weight
    //   NumericTraits<T2>::one().
    // - Computes mtry/msample via detail::fill_external_parameters and sets
    //   strata_ to intLabels_.
    //
    // NOTE(review): every row loop below is bounded by features.shape(0) but
    // indexes response(k,0); nothing checks that response has at least that
    // many rows. Consider a precondition on response.shape(0).
    // NOTE(review): std::set / std::find / std::runtime_error are used without
    // <set>, <algorithm>, <stdexcept> among the visible includes; presumably
    // they arrive via rf_common.hxx -- verify.
173  template<class T>
174  Processor(MultiArrayView<2, T1, C1>const & features,
175  MultiArrayView<2, T2, C2>const & response,
176  RandomForestOptions &options,
177  ProblemSpec<T> &ext_param)
178  :
179  features_( features) // do not touch the features.
180  {
181  vigra_precondition(!detail::contains_nan(features), "RandomForest(): Feature matrix "
182  "contains NaNs");
183  vigra_precondition(!detail::contains_nan(response), "RandomForest(): Response "
184  "contains NaNs");
185  vigra_precondition(!detail::contains_inf(features), "RandomForest(): Feature matrix "
186  "contains inf");
187  vigra_precondition(!detail::contains_inf(response), "RandomForest(): Response "
188  "contains inf");
189  // set some of the problem specific parameters
190  ext_param.column_count_ = features.shape(1);
191  ext_param.row_count_ = features.shape(0);
192  ext_param.problem_type_ = CLASSIFICATION;
193  ext_param.used_ = true;
194  intLabels_.reshape(response.shape());
195 
196  //get the class labels (only when the caller did not supply them)
197  if(ext_param.class_count_ == 0)
198  {
199  // collect the distinct labels in a std::set (which also sorts them);
200  // their positions in ext_param.classes become the integral labels.
201  std::set<T2> labelToInt;
202  for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
203  labelToInt.insert(response(k,0));
204  std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
205  ext_param.classes_(tmp_.begin(), tmp_.end());
206  }
    // Convert each label to its index in ext_param.classes.
    // NOTE(review): std::find runs twice per row (test, then index); the
    // iterator from the first search could be reused.
207  for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
208  {
209  if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
210  {
211  throw std::runtime_error("RandomForest(): invalid label in training data.");
212  }
213  else
214  intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
215  - ext_param.classes.begin();
216  }
217  // set class weights (one weight of NumericTraits<T2>::one() per class,
    // only if none were supplied). class_count_ is presumably updated by
    // classes_() above -- not visible here.
218  if(ext_param.class_weights_.size() == 0)
219  {
221  tmp(static_cast<std::size_t>(ext_param.class_count_),
222  NumericTraits<T2>::one());
223  ext_param.class_weights(tmp.begin(), tmp.end());
224  }
225 
226  // set mtry and msample
227  detail::fill_external_parameters(options, ext_param);
228 
229  // set strata: the integral class labels serve as strata
230  strata_ = intLabels_;
231 
232  }
233 
234  /** Access the processed features
 *  (the caller's feature matrix, returned unchanged).
235  */
237  {
238  return features_;
239  }
240 
241  /** Access processed labels
 *  (a MultiArrayView over intLabels_).
242  */
244  {
245  return MultiArrayView<2, LabelInt>(intLabels_);
246  }
247 
248  /** Access processed strata
 *  (a flat ArrayVectorView over all intLabels_.size() elements of
 *  intLabels_).
249  */
251  {
252  return ArrayVectorView<LabelInt>(intLabels_.size(), intLabels_.data());
253  }
254 
255  /** Access strata fraction sizes - not used currently.
 *  Always returns an empty ArrayVectorView<double>.
256  */
258  {
259  return ArrayVectorView< double>();
260  }
261 };
262 
263 
264 
265 /** Regression Preprocessor - This basically does not do anything with the
266  * data.
 *
 * Features and responses are stored as views (no data is copied), so the
 * underlying arrays must outlive this object. Every response column is
 * treated as one output: ext_param.response_size_ and
 * ext_param.class_count_ are both set to response.shape(1).
267  */
268 template<class LabelType, class T1, class C1, class T2, class C2>
269 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
270 {
271 public:
272  // only views are created - no data copied.
273  MultiArrayView<2, T1, C1> features_;
274  MultiArrayView<2, T2, C2> response_;
275  RandomForestOptions const & options_;
    // NOTE(review): the constructor receives ProblemSpec<T>&, but this member
    // is ProblemSpec<LabelType> const &. If T differs from LabelType, the
    // binding needs a conversion. If ProblemSpec offers a converting
    // constructor (not visible here), this reference would bind to a
    // temporary and dangle. Verify that T == LabelType at all call sites.
276  ProblemSpec<LabelType> const &
277  ext_param_;
278  // will only be filled if needed
    // (the constructor allocates it as response.shape(0) x 1 but does not
    // write any values into it).
279  MultiArray<2, int> strata_;
    // NOTE(review): not initialized by the constructor below.
280  bool strata_filled;
281 
282  // copy the views.
    // The constructor also fills ext_param:
    // - column_count_ / row_count_ from features.shape(1) / shape(0);
    // - problem_type_ = REGRESSION and used_ = true;
    // - mtry / msample via detail::fill_external_parameters.
    // Only then does it fail a vigra_precondition if features or response
    // contain NaN or inf. Finally it allocates strata_, sets response_size_
    // and class_count_ to response.shape(1), and stores class_count_
    // zero-valued entries as ext_param's classes.
283  template<class T>
285  MultiArrayView<2, T2, C2> response,
286  RandomForestOptions const & options,
287  ProblemSpec<T>& ext_param)
288  :
289  features_(features),
290  response_(response),
291  options_(options),
292  ext_param_(ext_param)
293  {
294  // set some of the problem specific parameters
295  ext_param.column_count_ = features.shape(1);
296  ext_param.row_count_ = features.shape(0);
297  ext_param.problem_type_ = REGRESSION;
298  ext_param.used_ = true;
299  detail::fill_external_parameters(options, ext_param);
300  vigra_precondition(!detail::contains_nan(features), "Processor(): Feature Matrix "
301  "Contains NaNs");
302  vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
303  "Contains NaNs");
304  vigra_precondition(!detail::contains_inf(features), "Processor(): Feature Matrix "
305  "Contains inf");
306  vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
307  "Contains inf");
    // one stratum slot per sample; left unfilled here
308  strata_ = MultiArray<2, int> (MultiArrayShape<2>::type(response_.shape(0), 1));
309  ext_param.response_size_ = response.shape(1);
310  ext_param.class_count_ = response_.shape(1);
    // placeholder "classes": one zero entry per response column
311  std::vector<T2> tmp_(ext_param.class_count_, 0);
312  ext_param.classes_(tmp_.begin(), tmp_.end());
313  }
314 
315  /** access preprocessed features
 *  (the stored view, unchanged).
316  */
318  {
319  return features_;
320  }
321 
322  /** access preprocessed response
 *  (the stored view, unchanged).
323  */
325  {
326  return response_;
327  }
328 
329  /** access strata - this is not used currently.
 *  Returns strata_, which the constructor allocates but does not fill.
330  */
332  {
333  return strata_;
334  }
335 };
336 }
337 #endif //VIGRA_RF_PREPROCESSING_HXX
338 
339 
340 
ArrayVectorView< double > strata_prob()
Definition: rf_preprocessing.hxx:257
MultiArrayView< 2, LabelInt > response()
Definition: rf_preprocessing.hxx:243
Definition: rf_preprocessing.hxx:63
const difference_type & shape() const
Definition: multi_array.hxx:1648
Definition: array_vector.hxx:76
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:538
MultiArrayView< 2, T1, C1 > & features()
Definition: rf_preprocessing.hxx:317
Main MultiArray class containing the memory management.
Definition: multi_array.hxx:2474
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
MultiArrayView< 2, T1, C1 > const & features()
Definition: rf_preprocessing.hxx:236
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels -
Definition: rf_common.hxx:828
Definition: array_vector.hxx:58
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
MultiArrayView< 2, T2, C2 > & response()
Definition: rf_preprocessing.hxx:324
TinyVector< MultiArrayIndex, N > type
Definition: multi_shape.hxx:272
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights -
Definition: rf_common.hxx:844
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
MultiArray< 2, int > & strata()
Definition: rf_preprocessing.hxx:331
ArrayVectorView< LabelInt > strata()
Definition: rf_preprocessing.hxx:250
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
Options object for the random forest.
Definition: rf_common.hxx:170
const_iterator end() const
Definition: array_vector.hxx:237
size_type size() const
Definition: array_vector.hxx:358
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
int floor(FixedPoint< IntBits, FracBits > v)
rounding down.
Definition: fixedpoint.hxx:667
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)