File size: 3,098 Bytes
6229e10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#include <chrono>
#include <ctime>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

namespace sampling {

  // Common interface for callbacks. Every hook has a neutral default
  // implementation, so a subclass only redefines the hooks it cares about.
  class CallbackBase {
  public:
    CallbackBase () = default;
    virtual ~CallbackBase () = default;

    // Argument-less hook; the default implementation does nothing.
    virtual void on_bar_end () {
    }

    // Hook that receives a logits vector and a token value; the default
    // implementation ignores both.
    virtual void on_prediction (std::vector<float> &logits, int next_token) {
    }

    // Argument-less hook; the default implementation does nothing.
    virtual void on_start () {
    }

    // Temperature hook; the default hands current_temperature back unchanged.
    virtual float update_temperature(float current_temperature) {
      return current_temperature;
    }

    // Cancellation query; the default always answers false.
    virtual bool is_cancelled() {
      return false;
    }
  };

  // Owns a list of callbacks and forwards each hook to all of them, in the
  // order in which they were registered.
  class CallbackManager {
  public:
    CallbackManager () = default;
    ~CallbackManager () = default;

    // Registers a callback. A null pointer is ignored: every dispatch method
    // below dereferences each stored entry, so storing null would make the
    // next dispatch dereference a null pointer.
    void add_callback_ptr(std::shared_ptr<CallbackBase> x) {
      if (!x) {
        return;
      }
      // x is already a private copy; moving it avoids a second refcount bump.
      callbacks.push_back(std::move(x));
    }

    // Forwards on_bar_end to every registered callback.
    void on_bar_end () {
      for (const auto &callback : callbacks) {
        callback->on_bar_end();
      }
    }

    // Forwards on_prediction, with the same arguments, to every registered
    // callback. logits is passed by non-const reference to each of them.
    void on_prediction (std::vector<float> &logits, int next_token) {
      for (const auto &callback : callbacks) {
        callback->on_prediction(logits, next_token);
      }
    }

    // Forwards on_start to every registered callback.
    void on_start () {
      for (const auto &callback : callbacks) {
        callback->on_start();
      }
    }

    // Asks each callback in turn for a temperature, always passing the
    // original current_temperature. Returns the first answer that is strictly
    // greater than current_temperature; callbacks after that one are not
    // consulted. Returns current_temperature if no callback raises it.
    float update_temperature (float current_temperature) {
      for (const auto &callback : callbacks) {
        const float proposed = callback->update_temperature(current_temperature);
        if (proposed > current_temperature) {
          return proposed;
        }
      }
      return current_temperature;
    }

    // Returns true as soon as one callback reports cancellation; callbacks
    // after that one are not queried. Returns false for an empty list.
    bool is_cancelled() {
      for (const auto &callback : callbacks) {
        if (callback->is_cancelled()) {
          return true;
        }
      }
      return false;
    }

    // Registered callbacks in insertion order. The dispatch methods assume
    // every entry is non-null.
    std::vector<std::shared_ptr<CallbackBase>> callbacks;
  };


  // Callback examples
  // Callback that adds a fixed increment to whatever temperature it is handed,
  // remembers the result and prints it to standard output.
  class TemperatureIncreaseCallback : public CallbackBase {
  public:
    // Stores the increment and an initial temperature value. The stored
    // temperature is overwritten by the first call to update_temperature.
    TemperatureIncreaseCallback (float _increase, float _current_temperature)
      : increase(_increase),
        current_temperature(_current_temperature) {
    }

    // Redefines CallbackBase::update_temperature: returns temp + increase,
    // after recording it in current_temperature and logging it.
    float update_temperature(float temp) {
      const float raised = temp + increase;
      current_temperature = raised;
      std::cout << "CURRENT TEMPERATURE : " << current_temperature << std::endl;
      return current_temperature;
    }

    float increase;             // amount added on every update
    float current_temperature;  // last value returned by update_temperature
  };


  // Callback that sums logits[next_token] over all predictions and counts how
  // many predictions it has seen. on_start resets both values to zero.
  class LogLikelihoodCallback : public CallbackBase {
  public:
    LogLikelihoodCallback ()
      : loglik(0),
        sequence_length(0) {
    }

    // Redefines CallbackBase::on_prediction: adds the logits entry selected by
    // next_token to the running sum and bumps the prediction count.
    // next_token is used as an index without any bounds check.
    void on_prediction(std::vector<float> &logits, int next_token) {
      const float token_score = logits[next_token];
      loglik += token_score;
      ++sequence_length;
    }

    // Redefines CallbackBase::on_start: clears the sum and the count.
    void on_start() {
      sequence_length = 0;
      loglik = 0;
    }

    double loglik;        // running sum of the selected logits entries
    int sequence_length;  // number of on_prediction calls since the last reset
  };

  // Callback that records every next_token it is shown, in call order.
  class RecordTokenSequenceCallback : public CallbackBase {
  public:
    RecordTokenSequenceCallback () = default;

    // Redefines CallbackBase::on_start: discards everything recorded so far.
    void on_start() {
      tokens.clear();
    }

    // Redefines CallbackBase::on_prediction: appends next_token to the record.
    // The logits argument is not used.
    void on_prediction(std::vector<float> &logits, int next_token) {
      tokens.emplace_back(next_token);
    }

    std::vector<int> tokens;  // recorded tokens since the last on_start
  };

  // Callback whose is_cancelled answer is a flag the owner can set.
  // NOTE(review): the flag is a plain bool with no synchronization; if
  // set_cancel is ever called from a different thread than is_cancelled,
  // this needs revisiting — confirm against the callers.
  class CancelCallback : public CallbackBase {
  public:
    // Starts with the flag cleared.
    CancelCallback ()
      : cancel(false) {
    }

    // Overwrites the flag with cancel_value.
    void set_cancel(bool cancel_value) {
      cancel = cancel_value;
    }

    // Redefines CallbackBase::is_cancelled: reports the current flag value.
    bool is_cancelled() {
      return cancel;
    }

    bool cancel;  // current cancellation flag
  };

}