00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #ifndef OST_QA_MULTI_CLASSIFIER_HH
00020 #define OST_QA_MULTI_CLASSIFIER_HH
00021
#include <ost/stdint.hh>
#include <vector>
#include <cassert>
#include <cmath>
#include <fstream>
#include <ost/message.hh>
#include <iostream>

#include <boost/shared_ptr.hpp>

#include "index.hh"
#include <ost/config.hh>
00033
00034 namespace ost { namespace qa {
00035
00036 namespace impl {
00037
00038
00039
00040
00041
/// Placeholder type for an unused type slot. Its Serialize() accepts any
/// stream type and reads/writes nothing.
struct NullType {
  template <typename DS>
  void Serialize(DS& /*ds*/)
  {
    // nothing to (de)serialize
  }
};

/// Compile-time count of type slots up to and including the last slot that
/// is not NullType. With the intended usage (real types first, NullType
/// padding afterwards) this equals the number of real types.
template <typename T1, typename T2, typename T3, typename T4,
          typename T5, typename T6, typename T7>
struct LengthOf {
private:
  // Drop the first slot, pad the end with NullType and recurse.
  typedef LengthOf<T2, T3, T4, T5, T6, T7, NullType> Tail;
public:
  enum { Value = Tail::Value + 1 };
};

/// Recursion anchor: all seven slots hold NullType.
template <>
struct LengthOf<NullType, NullType, NullType, NullType, NullType, NullType,
                NullType> {
  enum { Value = 0 };
};
00071
00073
/// Compile-time type selection: If<C, T, F>::Type is T when C is true and
/// F when C is false.
template <bool C, typename T, typename F>
struct If {
  typedef T Type;   // C == true
};

template <typename T, typename F>
struct If<false, T, F> {
  typedef F Type;   // C == false
};
00090
00092
00093 template <typename T1,
00094 typename T2>
00095 struct IsEqual;
00096 template <typename T1,
00097 typename T2>
00098 struct IsEqual {
00099 enum { Value = false };
00100 };
00101 template <typename T>
00102 struct IsEqual<T,T> {
00103 enum { Value = true };
00104 };
00105
00106
00107
00108
00109
00110 template <typename C, typename T, typename F>
00111 struct IfNull;
00112
00113 template <typename C,
00114 typename T,
00115 typename F>
00116 struct IfNull {
00117 typedef typename If<IsEqual<NullType, C>::Value, T, F>::Type Type;
00118 };
00119
00120 }
00121
00123 class DLLEXPORT_OST_QA ClassifierBase {
00124 public:
00125 ClassifierBase(uint32_t number_of_classes)
00126 : number_of_classes_(number_of_classes) {
00127 }
00128 ClassifierBase()
00129 : number_of_classes_(0) {}
00130 virtual ~ClassifierBase() {}
00131 uint32_t GetNumberOfClasses() const {
00132 return number_of_classes_;
00133 }
00134 protected:
00135 uint32_t number_of_classes_;
00136 };
00137
00139 class DLLEXPORT_OST_QA IntegralClassifier : public ClassifierBase {
00140 public:
00141 IntegralClassifier(uint32_t number_of_classes,
00142 int lower_bound)
00143 : ClassifierBase(number_of_classes),
00144 lower_bound_(lower_bound) {
00145 }
00146 uint32_t GetIndexOf(int value) const {
00147 uint32_t idx=(value-lower_bound_);
00148 assert(this->GetNumberOfClasses()>idx);
00149 return idx;
00150 }
00151 IntegralClassifier()
00152 : ClassifierBase(0),
00153 lower_bound_(0) {
00154 }
00155
00156 template <typename DS>
00157 void Serialize(DS& ds)
00158 {
00159 ds & number_of_classes_;
00160 ds & lower_bound_;
00161 }
00162 private:
00163 int32_t lower_bound_;
00164 };
00165
00167 class DLLEXPORT_OST_QA ContinuousClassifier : public ClassifierBase {
00168 public:
00169 ContinuousClassifier(uint32_t number_of_classes,
00170 Real lower_bound,
00171 Real upper_bound)
00172 : ClassifierBase(number_of_classes),
00173 lower_bound_(lower_bound),
00174 upper_bound_(upper_bound) {
00175 }
00176 uint32_t GetIndexOf(Real value) const {
00177 Real factor=(value-lower_bound_)/(upper_bound_-lower_bound_);
00178 uint32_t idx=uint32_t(floor(this->GetNumberOfClasses()*factor));
00179
00180 assert(this->GetNumberOfClasses()>idx);
00181 return idx;
00182 }
00183 ContinuousClassifier()
00184 : ClassifierBase(1),
00185 lower_bound_(0), upper_bound_(1) {
00186 }
00187 template <typename DS>
00188 void Serialize(DS& ds)
00189 {
00190 ds & number_of_classes_;
00191 ds & lower_bound_;
00192 ds & upper_bound_;
00193 }
00194 private:
00195 Real lower_bound_;
00196 Real upper_bound_;
00197 };
00198
00199
00200 template <typename T>
00201 struct Classifier;
00202
00203 template <>
00204 struct DLLEXPORT_OST_QA Classifier<int> {
00205 typedef IntegralClassifier Type;
00206 typedef const IntegralClassifier& ConstRefType;
00207 typedef IntegralClassifier& RefType;
00208 };
00209 template <>
00210 struct DLLEXPORT_OST_QA Classifier<Real> {
00211 typedef ContinuousClassifier Type;
00212 typedef const ContinuousClassifier& ConstRefType;
00213 typedef ContinuousClassifier& RefType;
00214 };
00215 #if OST_DOUBLE_PRECISION
00216 template <>
00217 struct DLLEXPORT_OST_QA Classifier<float> {
00218 typedef ContinuousClassifier Type;
00219 typedef const ContinuousClassifier& ConstRefType;
00220 typedef ContinuousClassifier& RefType;
00221 };
00222 #endif
00223 template <>
00224 struct DLLEXPORT_OST_QA Classifier<impl::NullType> {
00225 typedef impl::NullType Type;
00226 typedef const impl::NullType& ConstRefType;
00227 typedef impl::NullType& RefType;
00228 };
00229
00230 template <typename I>
00231 struct DLLEXPORT_OST_QA NullFind {
00232 NullFind(const ClassifierBase&,uint32_t,const impl::NullType&,I&) {};
00233 };
00234 template <typename C, typename T, typename I>
00235 struct IndexFind;
00236
00237 template <typename C, typename I>
00238 struct DLLEXPORT_OST_QA IndexFind<C,impl::NullType,I> {
00239 IndexFind(const C&,
00240 uint32_t,
00241 const impl::NullType&, I&) {
00242 }
00243 };
00244
00245 template <typename C, typename T, typename I>
00246 struct DLLEXPORT_OST_QA IndexFind {
00247 IndexFind(const C& classifier, uint32_t i, const T& value, I& index) {
00248 index[i]=classifier.GetIndexOf(value);
00249 }
00250 };
00251 template <typename T>
00252 struct NumberOfClasses;
00253
00254
00255 template <>
00256 struct DLLEXPORT_OST_QA NumberOfClasses<impl::NullType> {
00257 uint32_t operator ()(const impl::NullType& t) {
00258 return 1;
00259 }
00260 };
00261
00262 template <typename T>
00263 struct DLLEXPORT_OST_QA NumberOfClasses {
00264 uint32_t operator ()(const T& t) {
00265 return t.GetNumberOfClasses();
00266 }
00267 };
00268
00270 template <typename V, typename T1,
00271 typename T2=impl::NullType,
00272 typename T3=impl::NullType,
00273 typename T4=impl::NullType,
00274 typename T5=impl::NullType,
00275 typename T6=impl::NullType,
00276 typename T7=impl::NullType>
00277 class DLLEXPORT_OST_QA MultiClassifier {
00278 public:
00279 enum { Dimension = impl::LengthOf<T1, T2, T3, T4, T5, T6, T7>::Value };
00280 typedef V ValueType;
00281 typedef Index<MultiClassifier::Dimension> IndexType;
00282 typedef IndexIterator<Dimension> Iterator;
00283 typedef Classifier<T1> C1;
00284 typedef Classifier<T2> C2;
00285 typedef Classifier<T3> C3;
00286 typedef Classifier<T4> C4;
00287 typedef Classifier<T5> C5;
00288 typedef Classifier<T6> C6;
00289 typedef Classifier<T7> C7;
00290 #if WIN32
00291 MultiClassifier(const V& initial_value,
00292 typename C1::ConstRefType c1,
00293 typename C2::ConstRefType c2=C2::Type(),
00294 typename C3::ConstRefType c3=C3::Type(),
00295 typename C4::ConstRefType c4=C4::Type(),
00296 typename C5::ConstRefType c5=C5::Type(),
00297 typename C6::ConstRefType c6=C6::Type(),
00298 typename C7::ConstRefType c7=C7::Type())
00299 #else
00300 MultiClassifier(const V& initial_value,
00301 typename C1::ConstRefType c1,
00302 typename C2::ConstRefType c2=typename C2::Type(),
00303 typename C3::ConstRefType c3=typename C3::Type(),
00304 typename C4::ConstRefType c4=typename C4::Type(),
00305 typename C5::ConstRefType c5=typename C5::Type(),
00306 typename C6::ConstRefType c6=typename C6::Type(),
00307 typename C7::ConstRefType c7=typename C7::Type())
00308 #endif
00309 : classifier1_(c1), classifier2_(c2), classifier3_(c3),
00310 classifier4_(c4), classifier5_(c5), classifier6_(c6),
00311 classifier7_(c7) {
00312 this->ExtractNumberOfClasses();
00313
00314 uint32_t total=this->CalculateNumberOfBuckets();
00315 buckets_.resize(total, initial_value);
00316 }
00317
00318 MultiClassifier()
00319 {
00320 memset(number_of_classes_, 0, sizeof(number_of_classes_));
00321 }
00322
00323 template <typename DS>
00324 void Serialize(DS& ds)
00325 {
00326 ds & classifier1_;
00327 ds & classifier2_;
00328 ds & classifier3_;
00329 ds & classifier4_;
00330 ds & classifier5_;
00331 ds & classifier6_;
00332 ds & classifier7_;
00333 if (ds.IsSource()) {
00334 this->ExtractNumberOfClasses();
00335 }
00336 ds & buckets_;
00337 }
00338
00339 MultiClassifier(const MultiClassifier& rhs)
00340 : classifier1_(rhs.classifier1_), classifier2_(rhs.classifier2_),
00341 classifier3_(rhs.classifier3_), classifier4_(rhs.classifier4_),
00342 classifier5_(rhs.classifier5_), classifier6_(rhs.classifier6_),
00343 classifier7_(rhs.classifier7_) {
00344 this->ExtractNumberOfClasses();
00345 uint32_t total=this->CalculateNumberOfBuckets();
00346 buckets_.resize(total);
00347 memcpy(&buckets_.front(), &rhs.buckets_.front(), sizeof(V)*total);
00348 }
00349
00350 uint32_t GetBucketCount() const {
00351 return static_cast<uint32_t>(buckets_.size());
00352 }
00353
00354 void Add(const ValueType& value,
00355 T1 x1=T1(), T2 x2=T2(),
00356 T3 x3=T3(), T4 x4=T4(),
00357 T5 x5=T5(), T6 x6=T6(),
00358 T7 x7=T7()) {
00359 IndexType index=this->FindBucket(x1, x2, x3, x4, x5, x6, x7);
00360 uint32_t linear_index=this->LinearizeBucketIndex(index);
00361 buckets_[linear_index]+=value;
00362 }
00363
00364 const ValueType& Get(T1 x1=T1(), T2 x2=T2(),
00365 T3 x3=T3(), T4 x4=T4(),
00366 T5 x5=T5(), T6 x6=T6(), T7 x7=T7()) const {
00367 IndexType index=this->FindBucket(x1, x2, x3, x4, x5, x6, x7);
00368 uint32_t linear_index=this->LinearizeBucketIndex(index);
00369 return buckets_[linear_index];
00370 }
00371
00372 const ValueType& Get(const IndexType& index) const
00373 {
00374 return buckets_[this->LinearizeBucketIndex(index)];
00375 }
00376
00377 void Set(const IndexType& index, const ValueType& value)
00378 {
00379 buckets_[this->LinearizeBucketIndex(index)]=value;
00380 }
00381
00382
00383
00384 IndexType FindBucket(T1 x1=T1(), T2 x2=T2(), T3 x3=T3(),
00385 T4 x4=T4(), T5 x5=T5(), T6 x6=T6(),
00386 T7 x7=T7()) const {
00387
00388
00389 IndexType index;
00390 IndexFind<typename C1::Type, T1,
00391 IndexType> find_index_1(classifier1_, 0, x1, index);
00392 IndexFind<typename C2::Type, T2,
00393 IndexType> find_index_2(classifier2_, 1, x2, index);
00394 IndexFind<typename C3::Type, T3,
00395 IndexType> find_index_3(classifier3_, 2, x3, index);
00396 IndexFind<typename C4::Type, T4,
00397 IndexType> find_index_4(classifier4_, 3, x4, index);
00398 IndexFind<typename C5::Type, T5,
00399 IndexType> find_index_5(classifier5_, 4, x5, index);
00400 IndexFind<typename C6::Type, T6,
00401 IndexType> find_index_6(classifier6_, 5, x6, index);
00402 IndexFind<typename C7::Type, T7,
00403 IndexType> find_index_7(classifier7_, 6, x7, index);
00404 return index;
00405 }
00406
00407 void Add(const ValueType& value, const IndexType& index)
00408 {
00409 buckets_[this->LinearizeBucketIndex(index)]+=value;
00410 }
00411 private:
00412 void ExtractNumberOfClasses()
00413 {
00414 number_of_classes_[0]=NumberOfClasses<typename C1::Type>()(classifier1_);
00415 number_of_classes_[1]=NumberOfClasses<typename C2::Type>()(classifier2_);
00416 number_of_classes_[2]=NumberOfClasses<typename C3::Type>()(classifier3_);
00417 number_of_classes_[3]=NumberOfClasses<typename C4::Type>()(classifier4_);
00418 number_of_classes_[4]=NumberOfClasses<typename C5::Type>()(classifier5_);
00419 number_of_classes_[5]=NumberOfClasses<typename C6::Type>()(classifier6_);
00420 number_of_classes_[6]=NumberOfClasses<typename C7::Type>()(classifier7_);
00421 }
00422
00423 uint32_t LinearizeBucketIndex(const IndexType& index) const
00424 {
00425 uint32_t factor=1;
00426 uint32_t linear_index=0;
00427 for (uint32_t i=0; i<MultiClassifier::Dimension; ++i) {
00428 linear_index+=factor*index[i];
00429 factor*=number_of_classes_[i];
00430 }
00431 return linear_index;
00432 }
00433
00434 uint32_t CalculateNumberOfBuckets() const
00435 {
00436 uint32_t total=1;
00437 for (uint32_t i=0; i<MultiClassifier::Dimension; ++i) {
00438 total*=number_of_classes_[i];
00439 }
00440 return total;
00441 }
00442 typename C1::Type classifier1_;
00443 typename C2::Type classifier2_;
00444 typename C3::Type classifier3_;
00445 typename C4::Type classifier4_;
00446 typename C5::Type classifier5_;
00447 typename C6::Type classifier6_;
00448 typename C7::Type classifier7_;
00449 uint32_t number_of_classes_[7];
00450 std::vector<ValueType> buckets_;
00451 };
00452
00453 }}
00454
00455 #endif