|
|
#pragma once |
|
|
|
|
|
#include <algorithm> |
|
|
|
|
|
#include "representation.h" |
|
|
#include "util.h" |
|
|
|
|
|
#include "../data_structures/encoder_config.h" |
|
|
#include "../data_structures/train_config.h" |
|
|
#include "../data_structures/token_sequence.h" |
|
|
#include "../midi_parsing/midi_io.h" |
|
|
|
|
|
|
|
|
namespace encoder { |
|
|
|
|
|
template<class T> |
|
|
using matrix = std::vector<std::vector<T>>; |
|
|
|
|
|
std::vector<int> resolve_bar_infill_tokens(std::vector<int> &raw_tokens, const std::shared_ptr<REPRESENTATION> &rep) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "resolving bar infill" ); |
|
|
int fill_pholder = rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER, 0); |
|
|
int fill_start = rep->encode(midi::TOKEN_FILL_IN_START, 0); |
|
|
int fill_end = rep->encode(midi::TOKEN_FILL_IN_END, 0); |
|
|
|
|
|
std::vector<int> tokens; |
|
|
|
|
|
auto start_pholder = raw_tokens.begin(); |
|
|
auto start_fill = raw_tokens.begin(); |
|
|
auto end_fill = raw_tokens.begin(); |
|
|
|
|
|
while (start_pholder != raw_tokens.end()) { |
|
|
start_pholder = next(start_pholder); |
|
|
auto last_start_pholder = start_pholder; |
|
|
start_pholder = find(start_pholder, raw_tokens.end(), fill_pholder); |
|
|
if (start_pholder != raw_tokens.end()) { |
|
|
start_fill = find(next(start_fill), raw_tokens.end(), fill_start); |
|
|
end_fill = find(next(end_fill), raw_tokens.end(), fill_end); |
|
|
|
|
|
|
|
|
tokens.insert(tokens.end(), last_start_pholder, start_pholder); |
|
|
tokens.insert(tokens.end(), next(start_fill), end_fill); |
|
|
} |
|
|
else { |
|
|
|
|
|
start_fill = find(raw_tokens.begin(), raw_tokens.end(), fill_start); |
|
|
tokens.insert(tokens.end(), last_start_pholder, start_fill); |
|
|
} |
|
|
} |
|
|
return tokens; |
|
|
} |
|
|
|
|
|
class ENCODER { |
|
|
public: |
|
|
|
|
|
virtual ~ENCODER() {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> get_attribute_control_types() { |
|
|
std::vector<std::string> types; |
|
|
auto enum_descriptor = google::protobuf::GetEnumDescriptor<midi::ATTRIBUTE_CONTROL_TYPE>(); |
|
|
for (auto c : attribute_control_types) { |
|
|
types.push_back(enum_descriptor->FindValueByNumber(c)->name()); |
|
|
} |
|
|
return types; |
|
|
} |
|
|
|
|
|
virtual void preprocess_piece(midi::Piece *p) { |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<int> encode(midi::Piece *p) { |
|
|
preprocess_piece(p); |
|
|
data_structures::TokenSequence ts = encode_piece(p); |
|
|
return ts.tokens; |
|
|
} |
|
|
|
|
|
std::vector<int> encode_wo_preprocess(midi::Piece *p) { |
|
|
data_structures::TokenSequence ts = encode_piece(p); |
|
|
return ts.tokens; |
|
|
} |
|
|
|
|
|
virtual void decode(std::vector<int> &tokens, midi::Piece *p) { |
|
|
if (config->do_multi_fill == true) { |
|
|
tokens = resolve_bar_infill_tokens(tokens, rep); |
|
|
} |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "AFTER BAR INFILL RESOLVED :: "); |
|
|
for (int tok : tokens) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, pretty(tok)); |
|
|
} |
|
|
decode_track(tokens, p, rep, config); |
|
|
} |
|
|
|
|
|
std::string midi_to_json(const std::string &filepath) { |
|
|
midi::Piece p; |
|
|
midi_io::ParseSong(filepath, &p, config); |
|
|
preprocess_piece(&p); |
|
|
std::string json_string; |
|
|
google::protobuf::util::MessageToJsonString(p, &json_string); |
|
|
return json_string; |
|
|
} |
|
|
|
|
|
void midi_to_piece(const std::string& filepath, midi::Piece* p) { |
|
|
midi_io::ParseSong(filepath, p, config); |
|
|
preprocess_piece(p); |
|
|
} |
|
|
|
|
|
std::vector<int> midi_to_tokens(std::string &filepath) { |
|
|
midi::Piece p; |
|
|
midi_io::ParseSong(filepath, &p, config); |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Parsed File :: ",util_protobuf::protobuf_to_string(&p))); |
|
|
return encode(&p); |
|
|
} |
|
|
|
|
|
void json_to_midi(std::string &json_string, std::string &filepath) { |
|
|
midi::Piece p; |
|
|
google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); |
|
|
midi_io::write_midi(&p, filepath, -1); |
|
|
} |
|
|
|
|
|
std::string json_to_json(std::string &json_string_in) { |
|
|
midi::Piece p; |
|
|
google::protobuf::util::JsonStringToMessage(json_string_in.c_str(), &p); |
|
|
std::string json_string; |
|
|
google::protobuf::util::MessageToJsonString(p, &json_string); |
|
|
return json_string; |
|
|
} |
|
|
|
|
|
void json_track_to_midi(std::string &json_string, std::string &filepath, int single_track) { |
|
|
midi::Piece p; |
|
|
google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); |
|
|
midi_io::write_midi(&p, filepath, single_track); |
|
|
} |
|
|
|
|
|
std::vector<int> json_to_tokens(std::string &json_string) { |
|
|
midi::Piece p; |
|
|
google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); |
|
|
return encode(&p); |
|
|
} |
|
|
|
|
|
std::string tokens_to_json(std::vector<int> &tokens) { |
|
|
midi::Piece p; |
|
|
decode(tokens, &p); |
|
|
std::string json_string; |
|
|
google::protobuf::util::MessageToJsonString(p, &json_string); |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Decoded File :: ",json_string)); |
|
|
return json_string; |
|
|
} |
|
|
|
|
|
void resample_delta(midi::Piece *p) { |
|
|
|
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_VERBOSE, "Resampling Piece with Delta values"); |
|
|
|
|
|
|
|
|
std::map<int, int> delta_to_apply; |
|
|
int track_num = 0; |
|
|
for (const auto &track : p->tracks()) { |
|
|
int bar_num = 0; |
|
|
for (const auto &bar : track.bars()) { |
|
|
std::map<int, std::vector<std::tuple<int, int, int, int>>> pitch_to_events; |
|
|
for (int i=0; i<bar.events_size(); i++) { |
|
|
int event_idx = bar.events()[i]; |
|
|
midi::Event event = p->events(event_idx); |
|
|
pitch_to_events[event.pitch()].push_back(std::make_tuple(event_idx, event.time(), event.velocity(), event.delta())); |
|
|
} |
|
|
for (auto line : pitch_to_events) { |
|
|
std::sort(line.second.begin(), line.second.end(), [](std::tuple<int, int, int, int> a, std::tuple<int, int, int, int> b) { |
|
|
if (std::get<1>(a) < std::get<1>(b)) return true; |
|
|
if (std::get<1>(b) < std::get<1>(a)) return false; |
|
|
return (std::get<2>(a) < std::get<2>(b)); |
|
|
}); |
|
|
std::tuple<int, int, int> last_event; |
|
|
int last_offset_idx = -1; |
|
|
for (auto const& e : line.second) { |
|
|
|
|
|
if ((std::get<2>(e) > 0) && (last_offset_idx != -1)) { |
|
|
if ((std::get<3>(e) != 0) && (std::get<1>(e) == p->events(last_offset_idx).time())) { |
|
|
delta_to_apply[last_offset_idx] = std::get<3>(e); |
|
|
} |
|
|
} else if (std::get<2>(e) == 0) { |
|
|
last_offset_idx = std::get<0>(e); |
|
|
} |
|
|
} |
|
|
} |
|
|
bar_num++; |
|
|
} |
|
|
track_num++; |
|
|
} |
|
|
|
|
|
int current_res = config->resolution; |
|
|
int target_res = config->decode_resolution; |
|
|
p->set_resolution(target_res); |
|
|
p->set_internal_ticks_per_quarter(target_res); |
|
|
int old_time, new_time, delta; |
|
|
std::vector<std::tuple<int, midi::Event>> events_cache; |
|
|
|
|
|
|
|
|
int num_events = p->events_size(); |
|
|
for (int event_index=0; event_index<num_events; event_index++) { |
|
|
midi::Event e = p->events(event_index); |
|
|
old_time = e.time(); |
|
|
delta = e.delta(); |
|
|
if (delta_to_apply.count(event_index) > 0) { |
|
|
assert(delta_to_apply.count(event_index) == 1); |
|
|
delta = delta_to_apply[event_index]; |
|
|
} |
|
|
|
|
|
new_time = (int)(target_res * old_time / current_res); |
|
|
|
|
|
new_time = std::max(new_time + delta, 0); |
|
|
|
|
|
e.set_time(new_time); |
|
|
events_cache.push_back(std::make_tuple(event_index, e)); |
|
|
} |
|
|
|
|
|
sort(events_cache.begin(), events_cache.end(), [](std::tuple<int, midi::Event> a, std::tuple<int, midi::Event> b) { |
|
|
return std::get<0>(a) < std::get<0>(b); |
|
|
}); |
|
|
|
|
|
p->clear_events(); |
|
|
|
|
|
for (const std::tuple<int, midi::Event> &oe : events_cache) { |
|
|
midi::Event *ne = p->add_events(); |
|
|
ne->CopyFrom( std::get<1>(oe) ); |
|
|
} |
|
|
assert(num_events == p->events_size()); |
|
|
} |
|
|
|
|
|
std::string resample_delta_json(std::string &json_string) { |
|
|
std::string res_json_string; |
|
|
midi::Piece p; |
|
|
google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); |
|
|
if (config->use_microtiming) { |
|
|
resample_delta(&p); |
|
|
} |
|
|
google::protobuf::util::MessageToJsonString(p, &res_json_string); |
|
|
return res_json_string; |
|
|
} |
|
|
|
|
|
void tokens_to_json_array(std::vector<std::vector<int>> &seqs, std::vector<midi::Piece> &output) { |
|
|
for (int i=0; i<(int)seqs.size(); i++) { |
|
|
decode(seqs[i], &(output[i])); |
|
|
} |
|
|
} |
|
|
|
|
|
void tokens_to_midi(std::vector<int> &tokens, std::string &filepath) { |
|
|
midi::Piece p; |
|
|
decode(tokens, &p); |
|
|
midi_io::write_midi(&p, filepath, -1); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string pretty(int token) { |
|
|
return rep->pretty(token); |
|
|
} |
|
|
|
|
|
int vocab_size() { |
|
|
return rep->vocab_size; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void encode_notes(int bar_num, int track_num, midi::Piece *p, data_structures::TokenSequence *ts) { |
|
|
const auto track = p->tracks(track_num); |
|
|
const auto bar = track.bars(bar_num); |
|
|
const auto is_drum = data_structures::is_drum_track(track.track_type()); |
|
|
const int N_DURATION_TOKENS = rep->get_domain_size(midi::TOKEN_NOTE_DURATION); |
|
|
int N_TIME_TOKENS = rep->get_domain_size(midi::TOKEN_DELTA); |
|
|
|
|
|
|
|
|
std::vector<int> onsets; |
|
|
std::vector<int> onsets_idx; |
|
|
std::map<int, std::vector<int>> notes_by_onset; |
|
|
std::map<int, int> delta_onsets; |
|
|
int idx = 0; |
|
|
for (const auto &i : bar.events()) { |
|
|
midi::Event event = p->events(i); |
|
|
if ((event.internal_duration() > 0) && (event.velocity() > 0)) { |
|
|
if (notes_by_onset.find(event.time()) == notes_by_onset.end()) { |
|
|
onsets.push_back(event.time()); |
|
|
onsets_idx.push_back(idx); |
|
|
idx += 1; |
|
|
} |
|
|
notes_by_onset[event.time()].push_back(i); |
|
|
delta_onsets[i] = event.delta(); |
|
|
} |
|
|
} |
|
|
|
|
|
int last_velocity = -1; |
|
|
int onset; |
|
|
int d_onset; |
|
|
for (const auto &idx : onsets_idx) { |
|
|
onset = onsets[idx]; |
|
|
|
|
|
|
|
|
if ((onset > 0)) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_TIME_ABSOLUTE_POS, onset) ); |
|
|
} |
|
|
|
|
|
for (const auto &i : notes_by_onset[onset]) { |
|
|
midi::Event event = p->events(i); |
|
|
d_onset = delta_onsets[i]; |
|
|
if (rep->has_token_type(midi::TOKEN_VELOCITY_LEVEL)) { |
|
|
int current_velocity = rep->encode_partial(midi::TOKEN_VELOCITY_LEVEL, event.velocity()); |
|
|
if ((current_velocity > 0) && (current_velocity != last_velocity)) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_VELOCITY_LEVEL, event.velocity()) ); |
|
|
last_velocity = current_velocity; |
|
|
} |
|
|
} |
|
|
if (config->use_microtiming) { |
|
|
if (d_onset < 0) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_DELTA_DIRECTION, 0) ); |
|
|
d_onset *= -1; |
|
|
} |
|
|
d_onset = std::min(N_TIME_TOKENS - 1, d_onset); |
|
|
if (d_onset > 0) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_DELTA, d_onset) ); |
|
|
} |
|
|
} |
|
|
ts->push_back( rep->encode(midi::TOKEN_NOTE_ONSET, event.pitch()) ); |
|
|
if (!is_drum) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_NOTE_DURATION, std::min(event.internal_duration(), N_DURATION_TOKENS)-1) ); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
void encode_bar(int bar_num, int track_num, midi::Piece *p, data_structures::TokenSequence *ts, bool infill) { |
|
|
auto track = p->tracks(track_num); |
|
|
const auto bar = track.bars(bar_num); |
|
|
const auto is_drum = data_structures::is_drum_track(track.track_type()); |
|
|
|
|
|
ts->on_bar_start(p, rep); |
|
|
|
|
|
if (infill) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_FILL_IN_START, 0) ); |
|
|
encode_notes(bar_num, track_num, p, ts); |
|
|
ts->push_back( rep->encode(midi::TOKEN_FILL_IN_END, 0) ); |
|
|
} |
|
|
else { |
|
|
ts->push_back( rep->encode(midi::TOKEN_BAR, 0) ); |
|
|
|
|
|
midi::BarFeatures *bf = util_protobuf::GetBarFeatures(&track, bar_num); |
|
|
append_bar_tokens(ts, rep, bf, is_drum); |
|
|
|
|
|
if (rep->has_token_type(midi::TOKEN_TIME_SIGNATURE)) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_TIME_SIGNATURE, std::make_tuple(bar.ts_numerator(), bar.ts_denominator())) ); |
|
|
} |
|
|
|
|
|
if ((config->do_multi_fill) && (config->multi_fill.find(std::make_pair(track_num,bar_num)) != config->multi_fill.end())) { |
|
|
ts->push_back( rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER, 0) ); |
|
|
} |
|
|
else { |
|
|
encode_notes(bar_num, track_num, p, ts); |
|
|
} |
|
|
ts->push_back( rep->encode(midi::TOKEN_BAR_END, 0) ); |
|
|
} |
|
|
} |
|
|
|
|
|
void encode_track(int track_num, midi::Piece *p, data_structures::TokenSequence *ts) { |
|
|
const auto track = p->tracks(track_num); |
|
|
const auto is_drum = data_structures::is_drum_track(track.track_type()); |
|
|
const auto f = util_protobuf::GetTrackFeatures(p, track_num); |
|
|
|
|
|
ts->on_track_start(p, rep); |
|
|
|
|
|
ts->push_back( rep->encode(midi::TOKEN_TRACK, track.track_type()) ); |
|
|
|
|
|
append_track_pre_instrument_tokens(ts, rep, f, is_drum); |
|
|
|
|
|
if (rep->has_token_type(midi::TOKEN_INSTRUMENT)) { |
|
|
int inst = track.instrument(); |
|
|
ts->push_back( rep->encode(midi::TOKEN_INSTRUMENT, inst) ); |
|
|
} |
|
|
|
|
|
append_track_tokens(ts, rep, f, is_drum); |
|
|
|
|
|
for (int i=0; i<track.bars_size(); i++) { |
|
|
encode_bar(i, track_num, p, ts, false); |
|
|
} |
|
|
|
|
|
ts->push_back( rep->encode(midi::TOKEN_TRACK_END, 0) ); |
|
|
} |
|
|
|
|
|
data_structures::TokenSequence encode_piece(midi::Piece *p) { |
|
|
|
|
|
|
|
|
if ((!rep->has_token_type(midi::TOKEN_NOTE_DURATION)) || (!rep->has_token_type(midi::TOKEN_TIME_ABSOLUTE_POS))) { |
|
|
throw std::runtime_error("ERROR: ENCODING PIECE WITH DEPRECATED NOTE ENCODINGS"); |
|
|
} |
|
|
|
|
|
data_structures::TokenSequence ts(rep); |
|
|
|
|
|
ts.push_back( rep->encode( |
|
|
midi::TOKEN_PIECE_START, std::min((int)config->do_multi_fill,rep->get_domain_size(midi::TOKEN_PIECE_START)-1))); |
|
|
|
|
|
if (rep->has_token_type(midi::TOKEN_NUM_BARS)) { |
|
|
ts.push_back( rep->encode(midi::TOKEN_NUM_BARS, util_protobuf::GetNumBars(p)) ); |
|
|
} |
|
|
|
|
|
for (int i=0; i<p->tracks_size(); i++) { |
|
|
encode_track(i, p, &ts); |
|
|
} |
|
|
|
|
|
if (config->do_multi_fill) { |
|
|
for (const auto &track_bar : config->multi_fill) { |
|
|
encode_bar(std::get<1>(track_bar), std::get<0>(track_bar), p, &ts, true); |
|
|
} |
|
|
} |
|
|
|
|
|
return ts; |
|
|
} |
|
|
|
|
|
std::shared_ptr<REPRESENTATION> get_rep() { |
|
|
return rep; |
|
|
} |
|
|
|
|
|
std::shared_ptr<data_structures::EncoderConfig> config; |
|
|
std::shared_ptr<REPRESENTATION> rep; |
|
|
std::vector<midi::ATTRIBUTE_CONTROL_TYPE> attribute_control_types; |
|
|
}; |
|
|
|
|
|
} |
|
|
|
|
|
|