/* This file is part of MAUS: http://micewww.pp.rl.ac.uk/projects/maus
*
* MAUS is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* MAUS is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with MAUS. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include
#include "src/legacy/Interface/Squeal.hh"
#include "src/common_cpp/Utils/MAUSEvaluator.hh"
namespace MAUS {
// Construct an evaluator.
//
// Every Python handle starts out NULL. The constructor then calls
// Py_Initialize() and hands over to reset(), which tries to import the
// "evaluator" module and look up the callables used by evaluate() and
// set_variable(). reset() reports failures on std::cerr rather than throwing,
// so construction itself does not raise a Squeal.
MAUSEvaluator::MAUSEvaluator()
    : _evaluator_mod(NULL),
      _evaluator(NULL),
      _evaluate_func(NULL),
      _set_variable_func(NULL) {
  Py_Initialize();
  reset();
}
// Destroy the evaluator, dropping every Python reference it still holds.
// Only the references are released here; no Py_Finalize() call is made.
MAUSEvaluator::~MAUSEvaluator() {
  this->clear();
}
// Define (or redefine) a named variable inside the Python evaluator.
//
//  name  - variable name handed to the Python set_variable callable
//  value - numeric value bound to that name
//
// Throws Squeal(Squeal::recoverable) if the argument tuple cannot be built,
// if no callable set_variable function is available (for example because
// reset() failed), or if the Python call itself raises. Any pending Python
// error is cleared before the Squeal is thrown.
void MAUSEvaluator::set_variable(std::string name, double value) {
  // argument tuple (name, value) for the Python set_variable call
  PyObject* py_arg = Py_BuildValue("(sd)", name.c_str(), value);
  if (py_arg == NULL) {
    PyErr_Clear();
    throw(Squeal(Squeal::recoverable,
                 "Failed to resolve arguments to set_variable",
                 "MAUSEvaluator::set_variable"));
  }

  // py_value stays NULL if the function handle is missing or not callable,
  // so that case is reported through the same error path as a failed call
  PyObject* py_value = NULL;
  if (_set_variable_func != NULL && PyCallable_Check(_set_variable_func)) {
    // on success py_value should be None (return from Evaluator.set_variable)
    py_value = PyObject_CallObject(_set_variable_func, py_arg);
  }

  // check for error
  if (py_value == NULL) {
    PyErr_Clear();
    Py_DECREF(py_arg);
    throw(Squeal(Squeal::recoverable,
                 "Failed to parse variable "+name,
                 "MAUSEvaluator::set_variable"));
  }

  // clean up the references created in this function
  Py_DECREF(py_value);
  Py_DECREF(py_arg);
}
// Evaluate an expression with the Python evaluator and return the result.
//
//  function - expression string passed to the Python evaluate callable
//
// Returns the value produced by Python, converted to a C++ double.
//
// Throws Squeal(Squeal::recoverable) if the argument tuple cannot be built,
// if no callable evaluate function is available (for example because reset()
// failed), if the Python call raises, or if the returned object cannot be
// converted to a double. Any pending Python error is cleared before the
// Squeal is thrown.
double MAUSEvaluator::evaluate(std::string function) {
  // argument tuple (function,) for the Python evaluate call
  PyObject* py_arg = Py_BuildValue("(s)", function.c_str());
  if (py_arg == NULL) {
    PyErr_Clear();
    throw(Squeal(Squeal::recoverable,
                 "Failed to build function "+function,
                 "MAUSEvaluator::evaluate"));
  }

  // run the evaluator to calculate the function value; py_value stays NULL
  // if the function handle is missing or not callable
  PyObject* py_value = NULL;
  if (_evaluate_func != NULL && PyCallable_Check(_evaluate_func)) {
    py_value = PyObject_CallObject(_evaluate_func, py_arg);
  }
  if (py_value == NULL) {
    PyErr_Clear();
    Py_DECREF(py_arg);
    throw(Squeal(Squeal::recoverable,
                 "Failed to evaluate expression \""+function+"\"",
                 "MAUSEvaluator::evaluate"));
  }

  // convert py_value into a C++ double; PyArg_Parse returns 0 and sets a
  // Python error if the object is not convertible, which must not be ignored
  double value = 0.;
  const int converted = PyArg_Parse(py_value, "d", &value);
  if (!converted) {
    PyErr_Clear();
    Py_DECREF(py_value);
    Py_DECREF(py_arg);
    throw(Squeal(Squeal::recoverable,
                 "Failed to convert result of expression \""+function+
                 "\" to a double",
                 "MAUSEvaluator::evaluate"));
  }

  // clean up the references created in this function
  Py_DECREF(py_value);
  Py_DECREF(py_arg);
  return value;
}
void MAUSEvaluator::clear() {
// clear any memory allocations
if (_set_variable_func != NULL) {
Py_DECREF(_set_variable_func);
_set_variable_func = NULL;
}
if (_evaluate_func != NULL) {
Py_DECREF(_evaluate_func);
_evaluate_func = NULL;
}
if (_evaluator != NULL) {
Py_DECREF(_evaluator);
_evaluator = NULL;
}
if (_evaluator_mod != NULL) {
Py_DECREF(_evaluator_mod);
_evaluator_mod = NULL;
}
}
void MAUSEvaluator::reset() {
// check that we don't have anything allocated already
clear();
// NOTE: I don't throw Squeals here because I'm nervous about
// set up/tear down order
// initialise evaluator module
_evaluator_mod = PyImport_ImportModule("evaluator");
if (_evaluator_mod == NULL) {
std::cerr << "Failed to import evaluator module" << std::endl;
return;
}
// initialise an evaluator object
// evaluator_mod_dict is evaluator.__dict__ is a borrowed reference
PyObject* evaluator_mod_dict = PyModule_GetDict(_evaluator_mod);
if (evaluator_mod_dict != NULL) {
// evaluator_init is a Evaluator.__init__ borrowed reference
PyObject* evaluator_init = PyDict_GetItemString
(evaluator_mod_dict, "Evaluator");
if (PyCallable_Check(evaluator_init)) {
_evaluator = PyObject_Call(evaluator_init, NULL, NULL);
}
}
if (_evaluator == NULL) {
clear();
std::cerr << "Failed to instantiate evaluator" << std::endl;
return;
}
// get the evaluator object evaluate() function
_evaluate_func = PyObject_GetAttrString(_evaluator, "evaluate");
if (_evaluate_func == NULL) {
clear();
std::cerr << "Failed to find evaluate function" << std::endl;
return;
}
// get the evaluator object set_variable()() function
_set_variable_func = PyObject_GetAttrString(_evaluator, "set_variable");
if (_set_variable_func == NULL) {
clear();
std::cerr << "Failed to find set_variables function" << std::endl;
return;
}
}
} // namespace MAUS