File size: 4,496 Bytes
6229e10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#pragma once

#include "encoder_base.h"
#include "util.h"
#include "attribute_control.h"
#include "../data_structures/track_type.h"
#include "../../inference/enum/velocity.h"
#include "../../inference/enum/timesigs.h"
#include "../../inference/enum/pretrain_group.h"
#include "../midi_parsing/util_protobuf.h"
#include "../../inference/protobuf/validate.h"

// START OF NAMESPACE
namespace encoder {

/// Concatenates two vectors into a new one: every element of `x` (in order)
/// followed by every element of `y` (in order). Neither input is modified.
template <typename T>
std::vector<T> operator+(std::vector<T> const &x, std::vector<T> const &y) {
  std::vector<T> joined;
  // Single allocation up front; the appends below never reallocate.
  joined.reserve(x.size() + y.size());
  for (auto const &head_item : x) {
    joined.push_back(head_item);
  }
  for (auto const &tail_item : y) {
    joined.push_back(tail_item);
  }
  return joined;
}

/// Encoder configured for "expressive" output: velocity-level tokens and
/// microtiming (TOKEN_DELTA / TOKEN_DELTA_DIRECTION) are enabled, together with
/// min/max note-duration, min/max polyphony and density attribute controls.
class ExpressiveEncoder : public ENCODER {
public:
  /// Builds the default scheme: resolution 12, delta resolution 1920 (also used
  /// as the decode resolution), a 192-entry absolute-position domain and a
  /// 96-entry delta domain.
  ExpressiveEncoder() {
    config = std::make_shared<data_structures::EncoderConfig>();
    config->both_in_one = true;
    config->use_velocity_levels = true;
    config->use_microtiming = true;
    config->resolution = 12;
    config->delta_resolution = 1920;
    // Decoding happens at the delta resolution; set_scheme() keeps this in sync.
    config->decode_resolution = config->delta_resolution;

    // NOTE(review): entry order presumably determines token ids in
    // REPRESENTATION; do not reorder without checking existing models.
    rep = std::make_shared<REPRESENTATION>(REPRESENTATION({
      {midi::TOKEN_PIECE_START, TOKEN_DOMAIN(2)},
      {midi::TOKEN_NUM_BARS, TOKEN_DOMAIN({4,8}, INT_VALUES_DOMAIN)},
      {midi::TOKEN_BAR, TOKEN_DOMAIN(1)},
      {midi::TOKEN_BAR_END, TOKEN_DOMAIN(1)},
      {midi::TOKEN_TIME_SIGNATURE, TOKEN_DOMAIN(
        enums::YELLOW_TS_MAP,TIMESIG_MAP_DOMAIN)},
      {midi::TOKEN_TRACK, TOKEN_DOMAIN({
        midi::STANDARD_TRACK,
        midi::STANDARD_DRUM_TRACK
      },INT_VALUES_DOMAIN)},
      {midi::TOKEN_TRACK_END, TOKEN_DOMAIN(1)},
      {midi::TOKEN_INSTRUMENT, TOKEN_DOMAIN(enums::PRETRAIN_GROUPING,INT_MAP_DOMAIN)},
      {midi::TOKEN_NOTE_ONSET, TOKEN_DOMAIN(128)},
      {midi::TOKEN_NOTE_DURATION, TOKEN_DOMAIN(96)},
      {midi::TOKEN_TIME_ABSOLUTE_POS, TOKEN_DOMAIN(192)},
      {midi::TOKEN_FILL_IN_PLACEHOLDER, TOKEN_DOMAIN(1)},
      {midi::TOKEN_FILL_IN_START, TOKEN_DOMAIN(1)},
      {midi::TOKEN_FILL_IN_END, TOKEN_DOMAIN(1)},
      {midi::TOKEN_DELTA, TOKEN_DOMAIN(96)},
      {midi::TOKEN_DELTA_DIRECTION, TOKEN_DOMAIN(1)},
      {midi::TOKEN_VELOCITY_LEVEL, TOKEN_DOMAIN(128)},

      add_attribute_control_to_representation(midi::TOKEN_MIN_NOTE_DURATION),
      add_attribute_control_to_representation(midi::TOKEN_MAX_NOTE_DURATION),
      add_attribute_control_to_representation(midi::TOKEN_MIN_POLYPHONY),
      add_attribute_control_to_representation(midi::TOKEN_MAX_POLYPHONY),
      add_attribute_control_to_representation(midi::TOKEN_DENSITY_LEVEL),
    }));

  }
  ~ExpressiveEncoder() {}

  /// Runs the util_protobuf preprocessing passes on `p` before encoding:
  /// note durations first, then average polyphony / note duration, then note
  /// density (the attribute-control values this encoder's vocabulary exposes).
  void preprocess_piece(midi::Piece *p) {
    util_protobuf::calculate_note_durations(p);
    util_protobuf::update_av_polyphony_and_note_duration(p);
    util_protobuf::update_note_density(p);
  }

  /// Re-configures the time grid and rebuilds the token vocabulary.
  /// @param res                 stored in config->resolution
  /// @param delta_res           stored in config->delta_resolution and, to keep
  ///                            the constructor's invariant, decode_resolution
  /// @param delta_vocab_size    size of the TOKEN_DELTA domain
  /// @param abs_pos_vocab_size  size of the TOKEN_TIME_ABSOLUTE_POS domain
  void set_scheme(int res, int delta_res, int delta_vocab_size, int abs_pos_vocab_size) {
    config->resolution = res;
    config->delta_resolution = delta_res;
    // The constructor ties decode_resolution to delta_resolution; without this
    // line a new delta resolution would leave decode_resolution at its old
    // value (1920), mismatching the delta tokens being produced.
    config->decode_resolution = config->delta_resolution;

    // NOTE(review): unlike the constructor, TOKEN_VELOCITY_LEVEL is listed
    // last here. If REPRESENTATION assigns ids by position, this vocabulary's
    // layout differs from the default one; kept as-is so existing token ids
    // do not change — confirm whether this is intended.
    rep = std::make_shared<REPRESENTATION>(REPRESENTATION({
      {midi::TOKEN_PIECE_START, TOKEN_DOMAIN(2)},
      {midi::TOKEN_NUM_BARS, TOKEN_DOMAIN({4,8}, INT_VALUES_DOMAIN)},
      {midi::TOKEN_BAR, TOKEN_DOMAIN(1)},
      {midi::TOKEN_BAR_END, TOKEN_DOMAIN(1)},
      {midi::TOKEN_TIME_SIGNATURE, TOKEN_DOMAIN(
        enums::YELLOW_TS_MAP,TIMESIG_MAP_DOMAIN)},
      {midi::TOKEN_TRACK, TOKEN_DOMAIN({
        midi::STANDARD_TRACK,
        midi::STANDARD_DRUM_TRACK
      },INT_VALUES_DOMAIN)},
      {midi::TOKEN_TRACK_END, TOKEN_DOMAIN(1)},
      {midi::TOKEN_INSTRUMENT, TOKEN_DOMAIN(enums::PRETRAIN_GROUPING,INT_MAP_DOMAIN)},
      {midi::TOKEN_NOTE_ONSET, TOKEN_DOMAIN(128)},
      {midi::TOKEN_NOTE_DURATION, TOKEN_DOMAIN(96)},
      {midi::TOKEN_TIME_ABSOLUTE_POS, TOKEN_DOMAIN(abs_pos_vocab_size)},
      {midi::TOKEN_FILL_IN_PLACEHOLDER, TOKEN_DOMAIN(1)},
      {midi::TOKEN_FILL_IN_START, TOKEN_DOMAIN(1)},
      {midi::TOKEN_FILL_IN_END, TOKEN_DOMAIN(1)},
      {midi::TOKEN_DELTA, TOKEN_DOMAIN(delta_vocab_size)},
      {midi::TOKEN_DELTA_DIRECTION, TOKEN_DOMAIN(1)},

      add_attribute_control_to_representation(midi::TOKEN_MIN_NOTE_DURATION),
      add_attribute_control_to_representation(midi::TOKEN_MAX_NOTE_DURATION),
      add_attribute_control_to_representation(midi::TOKEN_MIN_POLYPHONY),
      add_attribute_control_to_representation(midi::TOKEN_MAX_POLYPHONY),
      add_attribute_control_to_representation(midi::TOKEN_DENSITY_LEVEL),

      {midi::TOKEN_VELOCITY_LEVEL, TOKEN_DOMAIN(128)}
    }));
  }
};

}
// END OF NAMESPACE