#pragma once

#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <google/protobuf/util/json_util.h>

#include "representation.h"
#include "util.h"
#include "../data_structures/encoder_config.h"
#include "../data_structures/train_config.h"
#include "../data_structures/token_sequence.h"
#include "../midi_parsing/midi_io.h"

// START OF NAMESPACE
namespace encoder {

template <typename T>
using matrix = std::vector<std::vector<T>>;

// Replace each TOKEN_FILL_IN_PLACEHOLDER in the prompt with the token span that was
// generated between the corresponding TOKEN_FILL_IN_START / TOKEN_FILL_IN_END markers.
std::vector<int> resolve_bar_infill_tokens(std::vector<int> &raw_tokens, const std::shared_ptr<REPRESENTATION> &rep) {
  data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "resolving bar infill");

  int fill_pholder = rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER, 0);
  int fill_start = rep->encode(midi::TOKEN_FILL_IN_START, 0);
  int fill_end = rep->encode(midi::TOKEN_FILL_IN_END, 0);

  std::vector<int> tokens;
  auto start_pholder = raw_tokens.begin();
  auto start_fill = raw_tokens.begin();
  auto end_fill = raw_tokens.begin();

  while (start_pholder != raw_tokens.end()) {
    start_pholder = next(start_pholder); // FIRST TOKEN IS PIECE_START ANYWAYS
    auto last_start_pholder = start_pholder;
    start_pholder = find(start_pholder, raw_tokens.end(), fill_pholder);
    if (start_pholder != raw_tokens.end()) {
      start_fill = find(next(start_fill), raw_tokens.end(), fill_start);
      end_fill = find(next(end_fill), raw_tokens.end(), fill_end);
      // insert from last_start_pholder --> start_pholder
      tokens.insert(tokens.end(), last_start_pholder, start_pholder);
      tokens.insert(tokens.end(), next(start_fill), end_fill);
    } else {
      // insert from last_start_pholder --> end of sequence (excluding fill)
      start_fill = find(raw_tokens.begin(), raw_tokens.end(), fill_start);
      tokens.insert(tokens.end(), last_start_pholder, start_fill);
    }
  }
  return tokens;
}

class ENCODER {
public:

  virtual ~ENCODER() {}

  // helper for simplicity
  // also used to keep track of the attribute controls in use
  std::vector<std::string> get_attribute_control_types() {
    std::vector<std::string> types;
    // the attribute-control enum is assumed to be midi::ATTRIBUTE_CONTROL_TYPE (from the midi proto)
    auto enum_descriptor = google::protobuf::GetEnumDescriptor<midi::ATTRIBUTE_CONTROL_TYPE>();
    for (auto c : attribute_control_types) {
      types.push_back(enum_descriptor->FindValueByNumber(c)->name());
    }
    return types;
  }

  virtual void preprocess_piece(midi::Piece *p) {
    // default is to do nothing
  }

  std::vector<int> encode(midi::Piece *p) {
    preprocess_piece(p);
    data_structures::TokenSequence ts = encode_piece(p);
    return ts.tokens;
  }

  std::vector<int> encode_wo_preprocess(midi::Piece *p) {
    data_structures::TokenSequence ts = encode_piece(p);
    return ts.tokens;
  }

  virtual void decode(std::vector<int> &tokens, midi::Piece *p) {
    if (config->do_multi_fill == true) {
      tokens = resolve_bar_infill_tokens(tokens, rep);
    }
    data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "AFTER BAR INFILL RESOLVED :: ");
    for (int tok : tokens) {
      data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, pretty(tok));
    }
    decode_track(tokens, p, rep, config);
  }

  std::string midi_to_json(const std::string &filepath) {
    midi::Piece p;
    midi_io::ParseSong(filepath, &p, config);
    preprocess_piece(&p); // add features that the encoder may need
    std::string json_string;
    google::protobuf::util::MessageToJsonString(p, &json_string);
    return json_string;
  }

  void midi_to_piece(const std::string& filepath, midi::Piece* p) {
    midi_io::ParseSong(filepath, p, config);
    preprocess_piece(p);
  }

  std::vector<int> midi_to_tokens(std::string &filepath) {
    midi::Piece p;
    midi_io::ParseSong(filepath, &p, config);
    data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Parsed File :: ", util_protobuf::protobuf_to_string(&p)));
    return encode(&p);
  }

  void json_to_midi(std::string &json_string, std::string &filepath) {
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p);
    midi_io::write_midi(&p, filepath, -1);
  }

  std::string json_to_json(std::string &json_string_in) {
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string_in.c_str(), &p);
    std::string json_string;
    google::protobuf::util::MessageToJsonString(p, &json_string);
    return json_string;
  }

  void json_track_to_midi(std::string &json_string, std::string &filepath, int single_track) {
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p);
    midi_io::write_midi(&p, filepath, single_track);
  }

  std::vector<int> json_to_tokens(std::string &json_string) {
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p);
    return encode(&p);
  }

  std::string tokens_to_json(std::vector<int> &tokens) {
    midi::Piece p;
    decode(tokens, &p);
    std::string json_string;
    google::protobuf::util::MessageToJsonString(p, &json_string);
    data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Decoded File :: ", json_string));
    return json_string;
  }

  void resample_delta(midi::Piece *p) {
    // Rewrite event times so they take into account their delta (microtiming) values,
    // while resampling from config->resolution to config->decode_resolution.
    data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_VERBOSE, "Resampling Piece with Delta values");

    // We have to deal with overlapping notes by applying the next note's onset delta
    // to the previous note's offset.
    std::map<int,int> delta_to_apply;
    int track_num = 0;
    for (const auto &track : p->tracks()) {
      int bar_num = 0;
      for (const auto &bar : track.bars()) {
        // pitch --> list of (event index, time, velocity, delta)
        std::map<int,std::vector<std::tuple<int,int,int,int>>> pitch_to_events;
        for (int i=0; i<bar.events_size(); i++) {
          int event_idx = bar.events(i);
          midi::Event event = p->events(event_idx);
          pitch_to_events[event.pitch()].push_back(std::make_tuple(event_idx, event.time(), event.velocity(), event.delta()));
        }
        for (auto line : pitch_to_events) {
          std::sort(line.second.begin(), line.second.end(), [](std::tuple<int,int,int,int> a, std::tuple<int,int,int,int> b) {
            if (std::get<1>(a) < std::get<1>(b)) return true;
            if (std::get<1>(b) < std::get<1>(a)) return false;
            return (std::get<2>(a) < std::get<2>(b));
          });
          std::tuple<int,int,int,int> last_event;
          int last_offset_idx = -1;
          for (auto const& e : line.second) {
            // if onset, check last offset
            if ((std::get<2>(e) > 0) && (last_offset_idx != -1)) {
              if ((std::get<3>(e) != 0) && (std::get<1>(e) == p->events(last_offset_idx).time())) {
                delta_to_apply[last_offset_idx] = std::get<3>(e);
              }
            } else if (std::get<2>(e) == 0) {
              last_offset_idx = std::get<0>(e);
            }
          }
        }
        bar_num++;
      }
      track_num++;
    }

    int current_res = config->resolution;
    int target_res = config->decode_resolution;
    p->set_resolution(target_res);
    p->set_internal_ticks_per_quarter(target_res);

    int old_time, new_time, delta;
    std::vector<std::tuple<int,midi::Event>> events_cache;

    // Get all events and store in cache vector
    int num_events = p->events_size();
    for (int event_index=0; event_index<num_events; event_index++) {
      midi::Event e = p->events(event_index);
      old_time = e.time();
      delta = e.delta();
      if (delta_to_apply.count(event_index) > 0) {
        assert(delta_to_apply.count(event_index) == 1);
        delta = delta_to_apply[event_index];
      }
      // We round down to be safe
      new_time = (int)(target_res * old_time / current_res);
      // exclude negative times
      new_time = std::max(new_time + delta, 0);
      // Set new resampled time
      e.set_time(new_time);
      events_cache.push_back(std::make_tuple(event_index, e));
    }

    // Sort events to replace in the correct order
    sort(events_cache.begin(), events_cache.end(), [](std::tuple<int,midi::Event> a, std::tuple<int,midi::Event> b) {
      return std::get<0>(a) < std::get<0>(b);
    });

    // Clear all events now that they're cached
    p->clear_events();

    // Reinject resampled events
    for (const std::tuple<int,midi::Event> &oe : events_cache) {
      midi::Event *ne = p->add_events();
      ne->CopyFrom(std::get<1>(oe));
    }
    assert(num_events == p->events_size());
  }

  std::string resample_delta_json(std::string &json_string) {
    std::string res_json_string;
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p);
    if (config->use_microtiming) {
      resample_delta(&p);
    }
    google::protobuf::util::MessageToJsonString(p, &res_json_string);
    return res_json_string;
  }

  void tokens_to_json_array(std::vector<std::vector<int>> &seqs, std::vector<midi::Piece> &output) {
    for (int i=0; i<(int)seqs.size(); i++) {
      decode(seqs[i], &(output[i]));
    }
  }

  void tokens_to_midi(std::vector<int> &tokens, std::string &filepath) {
    midi::Piece p;
    decode(tokens, &p);
    midi_io::write_midi(&p, filepath, -1);
  }

  // ====================
  // expose methods of rep that we need

  std::string pretty(int token) {
    return rep->pretty(token);
  }

  int vocab_size() {
    return rep->vocab_size;
  }

  // ====================
  // below is a simplified refactor of the encoding process
  // broken into clear functions to
  //   - encode notes within a bar
  //   - encode a bar
  //   - encode a track
  //   - encode a piece
  // ====================

  void encode_notes(int bar_num, int track_num, midi::Piece *p, data_structures::TokenSequence *ts) {
    const auto track = p->tracks(track_num);
    const auto bar = track.bars(bar_num);
    const auto is_drum = data_structures::is_drum_track(track.track_type());
    const int N_DURATION_TOKENS = rep->get_domain_size(midi::TOKEN_NOTE_DURATION);
    int N_TIME_TOKENS = rep->get_domain_size(midi::TOKEN_DELTA);

    // group notes by onset time
    std::vector<int> onsets;
    std::vector<int> onsets_idx;
    std::map<int,std::vector<int>> notes_by_onset;
    std::map<int,int> delta_onsets;
    int idx = 0;
    for (const auto &i : bar.events()) {
      midi::Event event = p->events(i);
      if ((event.internal_duration() > 0) && (event.velocity() > 0)) {
        if (notes_by_onset.find(event.time()) == notes_by_onset.end()) {
          onsets.push_back(event.time());
          onsets_idx.push_back(idx);
          idx += 1;
        }
        notes_by_onset[event.time()].push_back(i);
        delta_onsets[i] = event.delta();
      }
    }

    int last_velocity = -1;
    int onset;
    int d_onset;
    for (const auto &idx : onsets_idx) {
      onset = onsets[idx];
      // checking for onset > 0 is to make things backwards compatible with the old representation
      // however for randomly ordering onset times we need to include onset == 0
      if ((onset > 0)) {
        ts->push_back( rep->encode(midi::TOKEN_TIME_ABSOLUTE_POS, onset) );
      }
      for (const auto &i : notes_by_onset[onset]) {
        midi::Event event = p->events(i);
        d_onset = delta_onsets[i];
        if (rep->has_token_type(midi::TOKEN_VELOCITY_LEVEL)) {
          int current_velocity = rep->encode_partial(midi::TOKEN_VELOCITY_LEVEL, event.velocity());
          if ((current_velocity > 0) && (current_velocity != last_velocity)) {
            ts->push_back( rep->encode(midi::TOKEN_VELOCITY_LEVEL, event.velocity()) );
            last_velocity = current_velocity;
          }
        }
        if (config->use_microtiming) {
          if (d_onset < 0) {
            ts->push_back( rep->encode(midi::TOKEN_DELTA_DIRECTION, 0) );
            d_onset *= -1;
          }
          d_onset = std::min(N_TIME_TOKENS - 1, d_onset);
          if (d_onset > 0) {
            ts->push_back( rep->encode(midi::TOKEN_DELTA, d_onset) );
          }
        }
        ts->push_back( rep->encode(midi::TOKEN_NOTE_ONSET, event.pitch()) );
        if (!is_drum) {
          ts->push_back( rep->encode(midi::TOKEN_NOTE_DURATION, std::min(event.internal_duration(), N_DURATION_TOKENS) - 1) );
        }
      }
    }
  }

  void encode_bar(int bar_num, int track_num, midi::Piece *p, data_structures::TokenSequence *ts, bool infill) {
    auto track = p->tracks(track_num);
    const auto bar = track.bars(bar_num);
    const auto is_drum = data_structures::is_drum_track(track.track_type());
    ts->on_bar_start(p, rep);
    if (infill) {
      ts->push_back( rep->encode(midi::TOKEN_FILL_IN_START, 0) );
      encode_notes(bar_num, track_num, p, ts);
      ts->push_back( rep->encode(midi::TOKEN_FILL_IN_END, 0) );
    } else {
      ts->push_back( rep->encode(midi::TOKEN_BAR, 0) );
      midi::BarFeatures *bf = util_protobuf::GetBarFeatures(&track, bar_num);
      append_bar_tokens(ts, rep, bf, is_drum);
      if (rep->has_token_type(midi::TOKEN_TIME_SIGNATURE)) {
        ts->push_back( rep->encode(midi::TOKEN_TIME_SIGNATURE, std::make_tuple(bar.ts_numerator(), bar.ts_denominator())) );
      }
      if ((config->do_multi_fill) && (config->multi_fill.find(std::make_pair(track_num, bar_num)) != config->multi_fill.end())) {
        ts->push_back( rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER, 0) );
      } else {
        encode_notes(bar_num, track_num, p, ts);
      }
      ts->push_back( rep->encode(midi::TOKEN_BAR_END, 0) );
    }
  }

  void encode_track(int track_num, midi::Piece *p, data_structures::TokenSequence *ts) {
    const auto track = p->tracks(track_num);
    const auto is_drum = data_structures::is_drum_track(track.track_type());
    const auto f = util_protobuf::GetTrackFeatures(p, track_num);
    ts->on_track_start(p, rep);
    ts->push_back( rep->encode(midi::TOKEN_TRACK, track.track_type()) );
    append_track_pre_instrument_tokens(ts, rep, f, is_drum);
    if (rep->has_token_type(midi::TOKEN_INSTRUMENT)) {
      int inst = track.instrument();
      ts->push_back( rep->encode(midi::TOKEN_INSTRUMENT, inst) );
    }
    append_track_tokens(ts, rep, f, is_drum);
    for (int i=0; i<track.bars_size(); i++) {
      encode_bar(i, track_num, p, ts, false);
    }
    ts->push_back( rep->encode(midi::TOKEN_TRACK_END, 0) );
  }

  data_structures::TokenSequence encode_piece(midi::Piece *p) {
    // make sure that rep does not try to use deprecated note encodings
    if ((!rep->has_token_type(midi::TOKEN_NOTE_DURATION)) || (!rep->has_token_type(midi::TOKEN_TIME_ABSOLUTE_POS))) {
      throw std::runtime_error("ERROR: ENCODING PIECE WITH DEPRECATED NOTE ENCODINGS");
    }
    data_structures::TokenSequence ts(rep);
    ts.push_back( rep->encode(midi::TOKEN_PIECE_START, std::min((int)config->do_multi_fill, rep->get_domain_size(midi::TOKEN_PIECE_START) - 1)) );
    if (rep->has_token_type(midi::TOKEN_NUM_BARS)) {
      ts.push_back( rep->encode(midi::TOKEN_NUM_BARS, util_protobuf::GetNumBars(p)) );
    }
    for (int i=0; i<p->tracks_size(); i++) {
      encode_track(i, p, &ts);
    }
    if (config->do_multi_fill) {
      for (const auto &track_bar : config->multi_fill) {
        encode_bar(std::get<1>(track_bar), std::get<0>(track_bar), p, &ts, true);
      }
    }
    return ts;
  }

  std::shared_ptr<REPRESENTATION> get_rep() {
    return rep;
  }

  // concrete config / representation / attribute-control type names are assumed
  // to come from the headers included at the top of this file
  std::shared_ptr<data_structures::EncoderConfig> config;
  std::shared_ptr<REPRESENTATION> rep;
  std::vector<midi::ATTRIBUTE_CONTROL_TYPE> attribute_control_types;
};

}
// END OF NAMESPACE
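// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, kept as a comment): a concrete encoder is
// expected to populate `config` and `rep` before calling the conversion
// helpers above. The subclass name (MyEncoder) and the REPRESENTATION
// constructor arguments below are assumptions for illustration, not part of
// this header.
//
//   class MyEncoder : public encoder::ENCODER {
//   public:
//     MyEncoder() {
//       config = std::make_shared<data_structures::EncoderConfig>();
//       rep = std::make_shared<REPRESENTATION>(/* token domains ... */);
//     }
//   };
//
//   MyEncoder enc;
//   std::string path = "song.mid";
//   std::vector<int> tokens = enc.midi_to_tokens(path);  // MIDI file -> token ids
//   std::string json = enc.tokens_to_json(tokens);       // token ids -> midi::Piece as JSON
// ---------------------------------------------------------------------------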