// @(#)root/tmva $Id$ // Author: Rustem Ospanov /********************************************************************************** * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * * Package: TMVA * * Class : ModulekNN * * Web : http://tmva.sourceforge.net * * * * Description: * * Module for k-nearest neighbor algorithm * * * * Author: * * Rustem Ospanov - U. of Texas at Austin, USA * * * * Copyright (c) 2007: * * CERN, Switzerland * * MPI-K Heidelberg, Germany * * U. of Texas at Austin, USA * * * * Redistribution and use in source and binary forms, with or without * * modification, are permitted according to the terms listed in LICENSE * * (http://tmva.sourceforge.net/LICENSE) * **********************************************************************************/ #ifndef ROOT_TMVA_ModulekNN #define ROOT_TMVA_ModulekNN //______________________________________________________________________ /* kNN::Event describes point in input variable vector-space, with additional functionality like distance between points */ //______________________________________________________________________ // C++ #include #include #include #include #include // ROOT #ifndef ROOT_Rtypes #include "Rtypes.h" #endif #ifndef ROOT_TRandom #include "TRandom3.h" #endif #ifndef ROOT_TMVA_NodekNN #include "TMVA/NodekNN.h" #endif namespace TMVA { class MsgLogger; namespace kNN { typedef Float_t VarType; typedef std::vector VarVec; class Event { public: Event(); Event(const VarVec &vec, Double_t weight, Short_t type); Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec); ~Event(); Double_t GetWeight() const; VarType GetVar(UInt_t i) const; VarType GetTgt(UInt_t i) const; UInt_t GetNVar() const; UInt_t GetNTgt() const; Short_t GetType() const; // keep these two function separate VarType GetDist(VarType var, UInt_t ivar) const; VarType GetDist(const Event &other) const; void SetTargets(const VarVec &tvec); const VarVec& GetTargets() const; const VarVec& GetVars() const; void Print() const; void Print(std::ostream& os) const; private: VarVec fVar; // coordinates (variables) for knn search VarVec fTgt; // targets for regression analysis Double_t fWeight; // event weight Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types? }; typedef std::vector EventVec; typedef std::pair *, VarType> Elem; typedef std::list List; std::ostream& operator<<(std::ostream& os, const Event& event); class ModulekNN { public: typedef std::map > VarMap; public: ModulekNN(); ~ModulekNN(); void Clear(); void Add(const Event &event); Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = ""); Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const; Bool_t Find(UInt_t nfind, const std::string &option) const; const EventVec& GetEventVec() const; const List& GetkNNList() const; const Event& GetkNNEvent() const; const VarMap& GetVarMap() const; const std::map& GetMetric() const; void Print() const; void Print(std::ostream &os) const; private: Node* Optimize(UInt_t optimize_depth); void ComputeMetric(UInt_t ifrac); const Event Scale(const Event &event) const; private: static TRandom3 fgRndm; UInt_t fDimn; Node *fTree; std::map fVarScale; mutable List fkNNList; // latest result from kNN search mutable Event fkNNEvent; // latest event used for kNN search std::map fCount; // count number of events of each type EventVec fEvent; // vector of all events used to build tree and analysis VarMap fVar; // sorted map of variables in each dimension for all event types mutable MsgLogger* fLogger; // message logger MsgLogger& Log() const { return *fLogger; } }; // // inlined functions for Event class // inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const { const VarType var2 = GetVar(ivar); return (var1 - var2) * (var1 - var2); } inline Double_t Event::GetWeight() const { return fWeight; } inline VarType Event::GetVar(const UInt_t i) const { return fVar[i]; } inline VarType Event::GetTgt(const UInt_t i) const { return fTgt[i]; } inline UInt_t Event::GetNVar() const { return fVar.size(); } inline UInt_t Event::GetNTgt() const { return fTgt.size(); } inline Short_t Event::GetType() const { return fType; } // // inline functions for ModulekNN class // inline const List& ModulekNN::GetkNNList() const { return fkNNList; } inline const Event& ModulekNN::GetkNNEvent() const { return fkNNEvent; } inline const EventVec& ModulekNN::GetEventVec() const { return fEvent; } inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const { return fVar; } inline const std::map& ModulekNN::GetMetric() const { return fVarScale; } } // end of kNN namespace } // end of TMVA namespace #endif