XCSF  1.4.7
XCSF learning classifier system
pybind_callback_checkpoint.h
Go to the documentation of this file.
1 /*
2  * This program is free software: you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation, either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * This program is distributed in the hope that it will be useful,
8  * but WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10  * GNU General Public License for more details.
11  *
12  * You should have received a copy of the GNU General Public License
13  * along with this program. If not, see <http://www.gnu.org/licenses/>.
14  */
15 
24 #pragma once
25 
26 #include <limits>
27 #include <pybind11/numpy.h>
28 #include <pybind11/pybind11.h>
29 
30 namespace py = pybind11;
31 
32 extern "C" {
33 #include "xcsf.h"
34 }
35 
36 #include "pybind_callback.h"
37 #include "pybind_utils.h"
38 
43 {
44  public:
53  CheckpointCallback(py::str monitor, std::string filename,
54  bool save_best_only, int save_freq, bool verbose) :
60  {
61  std::ostringstream err;
62  std::string str = monitor.cast<std::string>();
63  if (str != "train" && str != "val") {
64  err << "invalid metric to monitor: " << str << std::endl;
65  throw std::invalid_argument(err.str());
66  }
67  if (save_freq < 0) {
68  err << "save_freq cannot be negative" << std::endl;
69  throw std::invalid_argument(err.str());
70  }
71  }
72 
77  void
78  save(struct XCSF *xcsf)
79  {
80  xcsf_save(xcsf, filename.c_str());
81  std::ostringstream status;
82  status << get_timestamp() << " CheckpointCallback: ";
83  status << "saved " << filename;
84  py::print(status.str());
85  }
86 
93  bool
94  run(struct XCSF *xcsf, py::dict metrics) override
95  {
96  py::list data = metrics[monitor];
97  py::list trials = metrics["trials"];
98  const double current_error = py::cast<double>(data[data.size() - 1]);
99  const int current_trial = py::cast<int>(trials[trials.size() - 1]);
100  if (current_trial >= save_trial + save_freq) {
101  if (!save_best_only || (current_error < best_error)) {
102  save_trial = current_trial;
103  save(xcsf);
104  }
105  if (current_error < best_error) {
106  best_error = current_error;
107  }
108  }
109  return false;
110  }
111 
116  void
117  finish(struct XCSF *xcsf) override
118  {
119  if (!save_best_only) {
120  save(xcsf);
121  }
122  }
123 
124  private:
125  py::str monitor;
126  std::string filename;
128  int save_freq;
129  bool verbose;
130 
131  double best_error = std::numeric_limits<double>::max();
132  int save_trial = 0;
133 };
Interface for Callbacks.
Callback to save XCSF at some frequency.
bool save_best_only
Whether to only save the best population.
int save_trial
Trial number at which the last checkpoint was made.
void finish(struct XCSF *xcsf) override
Executes any tasks at the end of fitting.
CheckpointCallback(py::str monitor, std::string filename, bool save_best_only, int save_freq, bool verbose)
Constructs a new checkpoint callback.
bool run(struct XCSF *xcsf, py::dict metrics) override
Performs callback operations.
int save_freq
Trial frequency at which to (possibly) make checkpoints.
std::string filename
Name of the file to save XCSF.
bool verbose
Whether to display messages when an action is taken.
void save(struct XCSF *xcsf)
Saves the state of XCSF.
py::str monitor
Name of the metric to monitor.
Definition: __init__.py:1
Interface for callbacks.
Utilities for Python library.
std::string get_timestamp()
Returns a formatted string for displaying time.
Definition: pybind_utils.h:31
XCSF data structure.
Definition: xcsf.h:85
size_t xcsf_save(const struct XCSF *xcsf, const char *filename)
Writes the current state of XCSF to a file.
Definition: xcsf.c:90
XCSF data structures.