/* * Project: RooFit * Authors: * Garima Singh, CERN 2023 * Jonas Rembser, CERN 2023 * * Copyright (c) 2023, CERN * * Redistribution and use in source and binary forms, * with or without modification, are permitted according to the terms * listed in LICENSE (http://roofit.sourceforge.net/license.txt) */ #ifndef RooFit_Detail_CodeSquashContext_h #define RooFit_Detail_CodeSquashContext_h #include #include #include #include #include #include #include #include #include template class RooTemplateProxy; namespace RooFit { namespace Detail { /// @brief A class to maintain the context for squashing of RooFit models into code. class CodeSquashContext { public: CodeSquashContext(std::map const &outputSizes) : _nodeOutputSizes(outputSizes) { } void addResult(RooAbsArg const *key, std::string const &value); void addResult(const char *key, std::string const &value); std::string const &getResult(RooAbsArg const &arg); template std::string const &getResult(RooTemplateProxy const &key) { return getResult(key.arg()); } /// @brief Figure out the output size of a node. It is the size of the /// vector observable that it depends on, or 1 if it doesn't depend on any /// or is a reducer node. /// @param key The node to look up the size for. std::size_t outputSize(RooFit::Detail::DataKey key) const { auto found = _nodeOutputSizes.find(key); if (found != _nodeOutputSizes.end()) return found->second; return 1; } void addToGlobalScope(std::string const &str); std::string assembleCode(std::string const &returnExpr); void addVecObs(const char *key, int idx); void addToCodeBody(RooAbsArg const *klass, std::string const &in); void addToCodeBody(std::string const &in, bool isScopeIndep = false); /// @brief Build the code to call the function with name `funcname`, passing some arguments. /// The arguments can either be doubles or some RooFit arguments whose /// results will be looked up in the context. template std::string buildCall(std::string const &funcname, Args_t const &...args) { std::stringstream ss; ss << funcname << "(" << buildArgs(args...) << ")"; return ss.str(); } /// @brief A class to manage loop scopes using the RAII technique. To wrap your code around a loop, /// simply place it between a brace inclosed scope with a call to beginLoop at the top. For e.g. /// { /// auto scope = ctx.beginLoop({<-set of vector observables to loop over->}); /// // your loop body code goes here. /// } class LoopScope { public: LoopScope(CodeSquashContext &ctx, std::vector &&vars) : _ctx{ctx}, _vars{vars} {} ~LoopScope() { _ctx.endLoop(*this); } std::vector const &vars() const { return _vars; } private: CodeSquashContext &_ctx; const std::vector _vars; }; std::unique_ptr beginLoop(RooAbsArg const *in); std::string getTmpVarName(); std::string buildArg(RooAbsCollection const &x); std::string buildArg(std::span arr); private: bool isScopeIndependent(RooAbsArg const *in) const; void endLoop(LoopScope const &scope); void addResult(TNamed const *key, std::string const &value); template {}, bool>::type = true> std::string buildArg(T x) { return RooNumber::toString(x); } // If input is integer, we want to print it into the code like one (i.e. avoid the unnecessary '.0000'). template {}, bool>::type = true> std::string buildArg(T x) { return std::to_string(x); } std::string buildArg(std::string const &x) { return x; } std::string buildArg(std::nullptr_t) { return "nullptr"; } std::string buildArg(RooAbsArg const &arg) { return getResult(arg); } template std::string buildArg(RooTemplateProxy const &arg) { return getResult(arg); } std::string buildArgs() { return ""; } template std::string buildArgs(Arg_t const &arg) { return buildArg(arg); } template std::string buildArgs(Arg_t const &arg, Args_t const &...args) { return buildArg(arg) + ", " + buildArgs(args...); } /// @brief Map of node names to their result strings. std::unordered_map _nodeNames; /// @brief Block of code that is placed before the rest of the function body. std::string _globalScope; /// @brief A map to keep track of the observable indices if they are non scalar. std::unordered_map _vecObsIndices; /// @brief Map of node output sizes. std::map _nodeOutputSizes; /// @brief Stores the squashed code body. std::string _code; /// @brief The current number of for loops the started. int _loopLevel = 0; /// @brief Index to get unique names for temporary variables. int _tmpVarIdx = 0; /// @brief Keeps track of the position to go back and insert code to. int _scopePtr = -1; /// @brief Stores code that eventually gets injected into main code body. /// Mainly used for placing decls outside of loops. std::string _tempScope; /// @brief A map to keep track of list names as assigned by addResult. std::unordered_map::Value_t, std::string> listNames; }; } // namespace Detail } // namespace RooFit #endif