|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
#include <map> |
|
|
#include <tuple> |
|
|
#include <set> |
|
|
#include <vector> |
|
|
|
|
|
#include "../../common/encoder/representation.h" |
|
|
#include "../../common/encoder/encoder_base.h" |
|
|
#include "../enum/encoder_types.h" |
|
|
#include "../../common/encoder/attribute_control.h" |
|
|
#include "../../common/data_structures/verbosity.h" |
|
|
#include "graph.h" |
|
|
|
|
|
namespace sampling { |
|
|
|
|
|
using TOKEN_EDGE = std::pair<midi::TOKEN_TYPE,midi::TOKEN_TYPE>; |
|
|
using CKPT_MAP_TYPE = std::map<std::tuple<int,enums::MODEL_TYPE>,std::tuple<enums::ENCODER_TYPE,std::string>>; |
|
|
|
|
|
class CONDITIONAL_REP_GRAPH { |
|
|
public: |
|
|
|
|
|
virtual ~CONDITIONAL_REP_GRAPH() {} |
|
|
|
|
|
virtual bool is_active(midi::TRACK_TYPE track_type) { |
|
|
throw std::runtime_error("CONDITIONAL_REP_GRAPH::is_active() : NOT IMPLEMENTED"); |
|
|
} |
|
|
|
|
|
template <typename T> |
|
|
void show(const std::vector<T> &x, const std::string &s) { |
|
|
std::ostringstream buffer; |
|
|
buffer << s << " :: ["; |
|
|
for (const auto &e : x) { |
|
|
buffer << util_protobuf::enum_to_string(e) << ", "; |
|
|
} |
|
|
buffer << "]"; |
|
|
data_structures::LOGGER(buffer.str()); |
|
|
} |
|
|
|
|
|
midi::TOKEN_TYPE possibly_skip(midi::TRACK_TYPE track_type, int last_token, const std::unique_ptr<REP_GRAPH> &rg, std::shared_ptr<encoder::REPRESENTATION> &rep, std::vector<int> & mask) { |
|
|
|
|
|
auto target_node = std::make_tuple(midi::TOKEN_NONE, 0); |
|
|
|
|
|
if (is_active(track_type)) { |
|
|
auto inferred_node = rg->graph.infer_node(last_token, rg->enc); |
|
|
auto next_nodes = graph->graph.get_next_nodes(inferred_node); |
|
|
auto next_global_nodes = rg->graph.get_next_nodes(inferred_node); |
|
|
|
|
|
if ((next_nodes.size() == 1)) { |
|
|
int loop_count = 0; |
|
|
while ((next_global_nodes.size() == 1) && (next_nodes[0] != next_global_nodes[0])) { |
|
|
target_node = next_global_nodes[0]; |
|
|
next_global_nodes = rg->graph.get_next_nodes(target_node); |
|
|
loop_count++; |
|
|
|
|
|
if (loop_count > 100) { |
|
|
throw std::runtime_error("CONDITIONAL_REP_GRAPH::possibly_skip() : INFINITE LOOP"); |
|
|
} |
|
|
} |
|
|
|
|
|
if (std::get<0>(target_node) != midi::TOKEN_NONE) { |
|
|
auto msg = data_structures::to_str("CONDITIONAL_REP_GRAPH::possibly_skip() : skip ", util_protobuf::enum_to_string(std::get<0>(target_node)), std::get<1>(target_node)); |
|
|
std::cout << msg << std::endl; |
|
|
data_structures::LOGGER(msg); |
|
|
rg->graph.skip(rg->graph.get_previous_nodes(target_node)[0]); |
|
|
rg->set_mask(rep->encode(std::get<0>(target_node), 0), mask); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
return std::get<0>(target_node); |
|
|
} |
|
|
|
|
|
std::unique_ptr<REP_GRAPH> graph; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
class INSTRUMENT_CONDITIONAL_REP_GRAPH : public CONDITIONAL_REP_GRAPH { |
|
|
public: |
|
|
|
|
|
INSTRUMENT_CONDITIONAL_REP_GRAPH(encoder::ENCODER *e, enums::MODEL_TYPE mt) { |
|
|
graph = std::make_unique<REP_GRAPH>(e,mt,encoder::get_drum_exclusive_token_types()); |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "INSTRUMENT_CONDITIONAL_REP_GRAPH" ); |
|
|
} |
|
|
|
|
|
bool is_active(midi::TRACK_TYPE track_type) { |
|
|
return (data_structures::is_drum_track(track_type) == false); |
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
class DRUM_CONDITIONAL_REP_GRAPH : public CONDITIONAL_REP_GRAPH { |
|
|
public: |
|
|
|
|
|
DRUM_CONDITIONAL_REP_GRAPH(encoder::ENCODER *e, enums::MODEL_TYPE mt) { |
|
|
graph = std::make_unique<REP_GRAPH>(e,mt,encoder::get_instrument_exclusive_token_types()); |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "DRUM_CONDITIONAL_REP_GRAPH" ); |
|
|
} |
|
|
|
|
|
bool is_active(midi::TRACK_TYPE track_type) { |
|
|
return data_structures::is_drum_track(track_type); |
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
class SAMPLE_CONTROL { |
|
|
public: |
|
|
SAMPLE_CONTROL(midi::Piece *piece, midi::Status *status, midi::HyperParam *param, midi::ModelMetadata *meta) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "SAMPLE_CONTROL" ); |
|
|
|
|
|
verbose = param->verbose(); |
|
|
|
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, util_protobuf::protobuf_to_string(status)); |
|
|
|
|
|
initialize(piece, status, param, meta); |
|
|
rep = enc->rep; |
|
|
rg = std::make_unique<REP_GRAPH>(enc.get(), model_type); |
|
|
instrument_rg = std::make_unique<INSTRUMENT_CONDITIONAL_REP_GRAPH>(enc.get(), model_type); |
|
|
drum_rg = std::make_unique<DRUM_CONDITIONAL_REP_GRAPH>(enc.get(), model_type); |
|
|
|
|
|
if ((!rg) || (!instrument_rg) || (!drum_rg)) { |
|
|
std::runtime_error("REP GRAPH CONSTRUCTOR FAILED"); |
|
|
} |
|
|
else { |
|
|
data_structures::LOGGER( "REP GRAPH CONSTRUCTOR SUCCESS" ); |
|
|
} |
|
|
|
|
|
parse_status(status); |
|
|
initialize_members(); |
|
|
|
|
|
} |
|
|
|
|
|
~SAMPLE_CONTROL() {} |
|
|
|
|
|
void initialize_members() { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "initialize_members" ); |
|
|
barlength = 4 * enc->config->resolution; |
|
|
timestep = 0; |
|
|
absolute_timestep = 0; |
|
|
bar_count = 0; |
|
|
track_count = 0; |
|
|
infill_bar_count = 0; |
|
|
finished = false; |
|
|
token_position = 0; |
|
|
num_delta_tokens = 0; |
|
|
} |
|
|
|
|
|
void set_bar_infill_prompt(std::vector<std::tuple<int,int>> &bars, midi::Piece *p, midi::Status *status, midi::HyperParam *param) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "set_bar_infill_prompt" ); |
|
|
|
|
|
if (p) { |
|
|
std::set<std::tuple<int,int>> barset; |
|
|
std::copy(bars.begin(), bars.end(), std::inserter(barset, barset.end())); |
|
|
enc->config->do_multi_fill = true; |
|
|
enc->config->multi_fill = barset; |
|
|
|
|
|
if (param->internal_skip_preprocess()) { |
|
|
util_protobuf::calculate_note_durations(p); |
|
|
prompt = enc->encode_wo_preprocess(p); |
|
|
} |
|
|
else { |
|
|
prompt = enc->encode(p); |
|
|
} |
|
|
|
|
|
data_structures::LOGGER( "FULL PROMPT " ); |
|
|
for (int i=0; i<(int)prompt.size(); i++) { |
|
|
data_structures::LOGGER( enc->rep->pretty(prompt[i]) ); |
|
|
} |
|
|
data_structures::LOGGER( "FULL PROMPT " ); |
|
|
|
|
|
int fill_start = enc->rep->encode(midi::TOKEN_FILL_IN_START,0); |
|
|
for (int index=0; index<(int)prompt.size(); index++) { |
|
|
if (prompt[index] == fill_start) { |
|
|
prompt.resize(index+1); |
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
else { |
|
|
throw std::runtime_error("MUST PROVIDE midi::Piece FOR BAR INFILL MODE"); |
|
|
} |
|
|
} |
|
|
|
|
|
void set_autoregressive_prompt(std::vector<midi::StatusTrack> &tracks, midi::Piece *p, midi::Status *status, midi::HyperParam *param) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "set_autoregressive_prompt" ); |
|
|
|
|
|
enc->config->do_multi_fill = false; |
|
|
|
|
|
if (p->tracks_size()) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "SET AUTOREGRESSIVE PROMPT" ); |
|
|
prompt = enc->encode(p); |
|
|
} |
|
|
else { |
|
|
prompt.push_back( enc->rep->encode(midi::TOKEN_PIECE_START,0) ); |
|
|
} |
|
|
} |
|
|
|
|
|
void initialize(midi::Piece *piece, midi::Status *status, midi::HyperParam *param, midi::ModelMetadata *meta) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "initialize" ); |
|
|
|
|
|
util_protobuf::UpdateHasNotes(piece); |
|
|
|
|
|
std::vector<midi::StatusTrack> tracks; |
|
|
std::vector<std::tuple<int,int>> bars; |
|
|
int num_cond_tracks = 0; |
|
|
int num_resample_tracks = 0; |
|
|
int num_infill_tracks = 0; |
|
|
std::vector<util_protobuf::STATUS_TRACK_TYPE> track_types; |
|
|
std::vector<int> order; |
|
|
std::vector<int> cond_tracks; |
|
|
|
|
|
int track_num = 0; |
|
|
for (const auto &track : status->tracks()) { |
|
|
util_protobuf::STATUS_TRACK_TYPE tt = util_protobuf::infer_track_type(track); |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("STATUS TRACK TYPE FOR ",track.track_id()," : ", tt)); |
|
|
switch( tt ) { |
|
|
case util_protobuf::CONDITION: |
|
|
order.push_back( num_cond_tracks ); |
|
|
cond_tracks.push_back( track.track_id() ); |
|
|
num_cond_tracks++; |
|
|
break; |
|
|
case util_protobuf::RESAMPLE: |
|
|
order.push_back( num_resample_tracks ); |
|
|
tracks.push_back( track ); |
|
|
num_resample_tracks++; |
|
|
break; |
|
|
case util_protobuf::INFILL : |
|
|
num_infill_tracks++; |
|
|
break; |
|
|
} |
|
|
track_types.push_back( tt ); |
|
|
int bar_num = 0; |
|
|
for (const auto &selected : track.selected_bars()) { |
|
|
if (selected) { |
|
|
bars.push_back( std::make_pair(track_num, bar_num) ); |
|
|
} |
|
|
bar_num++; |
|
|
} |
|
|
track_num++; |
|
|
} |
|
|
|
|
|
|
|
|
int verbose_track_num = 0; |
|
|
for (const auto &track_type : track_types) { |
|
|
data_structures::LOGGER(data_structures::to_str("TRACK ", verbose_track_num, " -> ", track_type)); |
|
|
verbose_track_num++; |
|
|
} |
|
|
|
|
|
|
|
|
int nb = status->tracks(0).selected_bars_size(); |
|
|
|
|
|
enc = enums::getEncoderFromString(meta->encoder()); |
|
|
|
|
|
if (num_infill_tracks > 0) { |
|
|
data_structures::LOGGER( "INFILL" ); |
|
|
model_type = enums::BAR_INFILL_MODEL; |
|
|
|
|
|
|
|
|
util_protobuf::prune_tracks( |
|
|
piece, arange(0,piece->tracks_size(),1), arange(0,nb,1)); |
|
|
|
|
|
|
|
|
inverse_order = arange(piece->tracks_size()); |
|
|
data_structures::LOGGER(data_structures::to_str("GENERATING ", bars.size(), " BARS")); |
|
|
set_bar_infill_prompt(bars, piece, status, param); |
|
|
|
|
|
} |
|
|
else { |
|
|
data_structures::LOGGER( "TRACK" ); |
|
|
model_type = enums::TRACK_MODEL; |
|
|
|
|
|
data_structures::LOGGER(data_structures::to_str("GENERATING ", num_resample_tracks, " TRACKS")); |
|
|
|
|
|
|
|
|
|
|
|
for (track_num=0; track_num<status->tracks_size(); track_num++) { |
|
|
if (track_types[track_num] == util_protobuf::RESAMPLE) { |
|
|
order[track_num] = order[track_num] + num_cond_tracks; |
|
|
} |
|
|
} |
|
|
inverse_order.resize(order.size()); |
|
|
for (int i=0; i<(int)order.size(); i++) { |
|
|
inverse_order[order[i]] = i; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
util_protobuf::prune_tracks(piece, cond_tracks, arange(0,nb,1)); |
|
|
|
|
|
data_structures::LOGGER( "AFTER PRUNE TRACKS ...." ); |
|
|
util_protobuf::print_piece_summary(piece); |
|
|
data_structures::LOGGER( "============================" ); |
|
|
|
|
|
set_autoregressive_prompt(tracks, piece, status, param); |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
void finalize(midi::Piece *piece) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "finalize" ); |
|
|
if (model_type == enums::TRACK_MODEL) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "Reordering tracks" ); |
|
|
util_protobuf::reorder_tracks(piece, inverse_order); |
|
|
} |
|
|
} |
|
|
|
|
|
void parse_status(midi::Status *status) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "parse_status" ); |
|
|
|
|
|
|
|
|
|
|
|
int tnum = 0; |
|
|
num_infill_bars = 0; |
|
|
for (const auto &track : status->tracks()) { |
|
|
int bnum = 0; |
|
|
for (const auto &bar : track.selected_bars()) { |
|
|
|
|
|
selected_bars.push_back( std::pair(tnum,bnum) ); |
|
|
num_infill_bars += (int)bar; |
|
|
bnum++; |
|
|
} |
|
|
tnum++; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_bars = status->tracks(0).selected_bars_size(); |
|
|
num_tracks = status->tracks_size(); |
|
|
|
|
|
for (int i=0; i<status->tracks_size(); i++) { |
|
|
midi::StatusTrack track = status->tracks(inverse_order[i]); |
|
|
std::vector<int> mask(rep->max_token(),0); |
|
|
|
|
|
|
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("TRACK: ", i, " - POLYPHONY HARD LIMIT: ", track.polyphony_hard_limit())); |
|
|
polyphony_hard_limits.push_back( track.polyphony_hard_limit() ); |
|
|
|
|
|
|
|
|
track_temperatures.push_back( track.temperature() ); |
|
|
|
|
|
|
|
|
rep->set_mask(midi::TOKEN_NUM_BARS, {track.selected_bars_size()}, mask, 1); |
|
|
|
|
|
|
|
|
int tt = track.track_type(); |
|
|
if (tt == midi::STANDARD_BOTH) { |
|
|
rep->set_mask(midi::TOKEN_TRACK, {-1}, mask, 1); |
|
|
} |
|
|
else { |
|
|
rep->set_mask(midi::TOKEN_TRACK, {tt}, mask, 1); |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int> insts = enums::GM_MOD[track.instrument()]; |
|
|
if ((enc->rep->has_pretrain_instrument_mapping()) && (tt == midi::STANDARD_TRACK)) { |
|
|
|
|
|
for (int i=0; i<(int)insts.size(); i++) { |
|
|
auto it = enums::PRETRAIN_GROUPING_V2.find(insts[i]); |
|
|
if (it != enums::PRETRAIN_GROUPING_V2.end()) { |
|
|
insts[i] = it->second; |
|
|
} |
|
|
else { |
|
|
throw std::runtime_error("CAN NOT FIND INSTRUMENT IN PRETRAIN GROUPING"); |
|
|
} |
|
|
} |
|
|
} |
|
|
rep->set_mask(midi::TOKEN_INSTRUMENT, insts, mask, 1); |
|
|
|
|
|
|
|
|
rep->set_mask(midi::TOKEN_DENSITY_LEVEL, {track.density()-1}, mask, 1); |
|
|
|
|
|
|
|
|
rep->set_mask(midi::TOKEN_MIN_POLYPHONY, {track.min_polyphony_q()-1}, mask, 1); |
|
|
rep->set_mask(midi::TOKEN_MAX_POLYPHONY, {track.max_polyphony_q()-1}, mask, 1); |
|
|
|
|
|
|
|
|
rep->set_mask( |
|
|
midi::TOKEN_MIN_NOTE_DURATION, {track.min_note_duration_q()-1}, mask, 1); |
|
|
rep->set_mask( |
|
|
midi::TOKEN_MAX_NOTE_DURATION, {track.max_note_duration_q()-1}, mask, 1); |
|
|
|
|
|
set_track_masks(rep, mask, &track); |
|
|
|
|
|
std::set<midi::TOKEN_TYPE> fixed = { |
|
|
midi::TOKEN_NUM_BARS, |
|
|
midi::TOKEN_TRACK, |
|
|
midi::TOKEN_GENRE, |
|
|
midi::TOKEN_INSTRUMENT, |
|
|
midi::TOKEN_TIME_SIGNATURE, |
|
|
|
|
|
midi::TOKEN_DENSITY_LEVEL, |
|
|
midi::TOKEN_MIN_POLYPHONY, |
|
|
midi::TOKEN_MAX_POLYPHONY, |
|
|
midi::TOKEN_MIN_NOTE_DURATION, |
|
|
midi::TOKEN_MAX_NOTE_DURATION, |
|
|
|
|
|
}; |
|
|
|
|
|
auto track_control_tokens = encoder::get_track_attribute_control_graph(); |
|
|
for (int i=0; i<(int)track_control_tokens.size()-1; i++) { |
|
|
if (fixed.find(track_control_tokens[i]) == fixed.end()) { |
|
|
fixed.insert(track_control_tokens[i]); |
|
|
} |
|
|
} |
|
|
|
|
|
if (verbose) { |
|
|
|
|
|
data_structures::LOGGER( "=======================" ); |
|
|
data_structures::LOGGER( "ATTRIBUTE MASK : " ); |
|
|
for (int i=0; i<(int)mask.size(); i++) { |
|
|
if (mask[i]) { |
|
|
data_structures::LOGGER( rep->pretty(i) ); |
|
|
} |
|
|
} |
|
|
data_structures::LOGGER( "=======================" ); |
|
|
} |
|
|
|
|
|
for (const auto &kv : rep->token_domains) { |
|
|
if (fixed.find(kv.first) == fixed.end()) { |
|
|
rep->set_mask(kv.first, {-1}, mask, 1); |
|
|
} |
|
|
} |
|
|
|
|
|
attribute_masks.push_back( mask ); |
|
|
|
|
|
|
|
|
std::vector<std::vector<int>> bar_masks; |
|
|
for (int bn=0; bn<track.bars_size(); bn++) { |
|
|
std::vector<int> bar_mask(mask); |
|
|
midi::StatusBar bar = track.bars(bn); |
|
|
if (rep->has_token_type(midi::TOKEN_TIME_SIGNATURE)) { |
|
|
int tstoken = rep->encode(midi::TOKEN_TIME_SIGNATURE, std::make_tuple(bar.ts_numerator(), bar.ts_denominator())); |
|
|
bar_mask[tstoken] = 1; |
|
|
} |
|
|
set_bar_masks(rep, bar_mask, &bar); |
|
|
bar_masks.push_back( bar_mask ); |
|
|
} |
|
|
attribute_bar_masks.push_back( bar_masks ); |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
void update(int token) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "controlhSAMPLECONTROL update" ); |
|
|
midi::TOKEN_TYPE tt = rep->get_token_type(token); |
|
|
switch (tt) { |
|
|
case midi::TOKEN_TRACK: { |
|
|
bar_count = 0; |
|
|
absolute_timestep = 0; |
|
|
bar_start_timestep = 0; |
|
|
onsets.clear(); |
|
|
note_expiry.clear(); |
|
|
current_track_type = static_cast<midi::TRACK_TYPE>(rep->decode(token)); |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_TRACK_END: { |
|
|
track_count += 1; |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_BAR: { |
|
|
timestep = 0; |
|
|
barlength = 4 * enc->config->resolution; |
|
|
absolute_timestep = bar_start_timestep; |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_BAR_END: { |
|
|
bar_count += 1; |
|
|
bar_start_timestep += barlength; |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_FILL_IN_START: { |
|
|
|
|
|
|
|
|
|
|
|
int fillp_token = enc->rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER,0); |
|
|
auto it = history.begin(); |
|
|
auto prev = it; |
|
|
for (int i=0; i<=infill_bar_count; i++) { |
|
|
prev = it; |
|
|
it = find(next(it), history.end(), fillp_token); |
|
|
} |
|
|
for (auto i=next(prev); i!=it; i++) { |
|
|
if (verbose) { |
|
|
data_structures::LOGGER(data_structures::to_str("BACKFILLING :: ", enc->rep->pretty(*i))); |
|
|
} |
|
|
update(*i); |
|
|
} |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_FILL_IN_END: { |
|
|
infill_bar_count += 1; |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_TIME_SIGNATURE: { |
|
|
std::tuple<int,int> ts = rep->decode_timesig(token); |
|
|
double ts_ratio = ((double)std::get<0>(ts) / std::get<1>(ts)); |
|
|
barlength = ts_ratio * 4 * enc->config->resolution; |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_TIME_ABSOLUTE_POS: { |
|
|
int t = rep->decode(token); |
|
|
timestep = t; |
|
|
absolute_timestep = bar_start_timestep + t; |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_DELTA: { |
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_NOTE_ONSET: { |
|
|
int pitch = rep->decode(token); |
|
|
onsets.insert( pitch ); |
|
|
|
|
|
if (data_structures::is_drum_track(current_track_type)) { |
|
|
|
|
|
last_token = token; |
|
|
update(rep->encode(midi::TOKEN_NOTE_DURATION,0)); |
|
|
} |
|
|
|
|
|
break; |
|
|
} |
|
|
case midi::TOKEN_NOTE_DURATION: { |
|
|
int dur = rep->decode(token) + 1; |
|
|
int pitch = rep->decode(last_token); |
|
|
note_expiry[dur + absolute_timestep].push_back(pitch); |
|
|
break; |
|
|
} |
|
|
default: |
|
|
break; |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int> to_remove; |
|
|
for (const auto &kv : note_expiry) { |
|
|
if (kv.first <= absolute_timestep) { |
|
|
for (const auto &pitch : kv.second) { |
|
|
onsets.erase( pitch ); |
|
|
} |
|
|
to_remove.push_back( kv.first ); |
|
|
} |
|
|
} |
|
|
|
|
|
for (const auto &t : to_remove) { |
|
|
note_expiry.erase( t ); |
|
|
} |
|
|
|
|
|
last_token = token; |
|
|
|
|
|
if (verbose) { |
|
|
data_structures::LOGGER(data_structures::to_str("ONSETS : ", onsets.size())); |
|
|
} |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
void set_mask(int last_token, std::vector<int> &mask) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "controlhSAMPLECONTROL set_mask" ); |
|
|
|
|
|
|
|
|
midi::TOKEN_TYPE last_tt = rep->get_token_type(last_token); |
|
|
bool is_drum = data_structures::is_drum_track(current_track_type); |
|
|
|
|
|
|
|
|
midi::TOKEN_TYPE inst_skip = instrument_rg->possibly_skip(current_track_type, last_token, rg, rep, mask); |
|
|
midi::TOKEN_TYPE drum_skip = drum_rg->possibly_skip(current_track_type, last_token, rg, rep, mask); |
|
|
|
|
|
if ((is_drum) && (last_tt == midi::TOKEN_NOTE_ONSET)) { |
|
|
|
|
|
rg->skip(midi::TOKEN_NOTE_ONSET); |
|
|
rg->set_mask(rep->encode(midi::TOKEN_NOTE_DURATION,0), mask); |
|
|
} |
|
|
else if ((inst_skip == midi::TOKEN_NONE) && (drum_skip == midi::TOKEN_NONE)) { |
|
|
rg->set_mask(last_token, mask); |
|
|
} |
|
|
|
|
|
|
|
|
for (const auto &pitch : onsets) { |
|
|
mask[rep->encode(midi::TOKEN_NOTE_ONSET,pitch)] = 0; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (timestep == barlength) { |
|
|
if (verbose) { |
|
|
data_structures::LOGGER( "HIT TIME LIMIT >>>>>>>>>>>>>>>>>>>> " ); |
|
|
} |
|
|
rep->set_mask(midi::TOKEN_NOTE_ONSET, {-1}, mask, 0); |
|
|
rep->set_mask(midi::TOKEN_VELOCITY_LEVEL, {-1}, mask, 0); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int hard_limit = 0; |
|
|
if (model_type == enums::TRACK_MODEL) { |
|
|
int index = std::min(track_count, num_tracks-1); |
|
|
hard_limit = polyphony_hard_limits[index]; |
|
|
} |
|
|
else { |
|
|
int index = std::min(infill_bar_count, num_infill_bars-1); |
|
|
hard_limit = polyphony_hard_limits[std::get<0>(selected_bars[index])]; |
|
|
} |
|
|
|
|
|
|
|
|
if ((int)onsets.size() >= hard_limit) { |
|
|
data_structures::LOGGER(data_structures::to_str("HIT HARD LIMIT ( ",(int)onsets.size()," >= ", hard_limit, " ) >>>>>>>>>>>>>>>>>>>> ")); |
|
|
rep->set_mask(midi::TOKEN_NOTE_ONSET, {-1}, mask, 0); |
|
|
rep->set_mask(midi::TOKEN_VELOCITY_LEVEL, {-1}, mask, 0); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
int delta_domain_limit = rep->get_domain_size(midi::TOKEN_DELTA); |
|
|
|
|
|
if (!enc->config->use_microtiming) { |
|
|
for (int td=0; td<delta_domain_limit; td++) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("NOT USE MICRO -> ","MASKING DELTA :: ", td)); |
|
|
mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; |
|
|
} |
|
|
} else { |
|
|
if (last_tt == midi::TOKEN_DELTA) { |
|
|
num_delta_tokens += 1; |
|
|
mask[rep->encode(midi::TOKEN_DELTA_DIRECTION,0)] = 0; |
|
|
} else { |
|
|
num_delta_tokens = 0; |
|
|
} |
|
|
|
|
|
if (num_delta_tokens > 0) { |
|
|
for (int td=0; td<delta_domain_limit; td++) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("MAX MULTIPLE -> ","MASKING DELTA :: ", td)); |
|
|
mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (delta_domain_limit) { |
|
|
int max_step = enc->config->step_to_delta(barlength - timestep, enc->config->resolution); |
|
|
if ((last_tt == midi::TOKEN_DELTA) && (rep->decode(last_token) == 0)){ |
|
|
max_step = enc->config->step_to_delta(timestep, enc->config->resolution); |
|
|
} |
|
|
int max_td = std::max(std::min(max_step, delta_domain_limit), 0); |
|
|
for (int td=max_td; td<delta_domain_limit; td++) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("MAX FORWARD/BACKWARD -> ","MASKING DELTA :: ", td)); |
|
|
mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (timestep == 0) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("AT START -> ","MASKING DELTA DIRECTION")); |
|
|
mask[rep->encode(midi::TOKEN_DELTA_DIRECTION,0)] = 0; |
|
|
} |
|
|
|
|
|
|
|
|
if (timestep == barlength) { |
|
|
for (int td=1; td<delta_domain_limit; td++) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("AT END -> ","MASKING DELTA :: ", 0)); |
|
|
mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int domain_limit = rep->get_domain_size(midi::TOKEN_TIME_ABSOLUTE_POS); |
|
|
if (domain_limit) { |
|
|
for (int td=0; td<=timestep; td++) { |
|
|
mask[rep->encode(midi::TOKEN_TIME_ABSOLUTE_POS,td)] = 0; |
|
|
} |
|
|
for (int td=barlength+1; td<domain_limit; td++) { |
|
|
mask[rep->encode(midi::TOKEN_TIME_ABSOLUTE_POS,td)] = 0; |
|
|
} |
|
|
} |
|
|
|
|
|
if (model_type == enums::TRACK_MODEL) { |
|
|
|
|
|
if (bar_count != num_bars) { |
|
|
rep->set_mask(midi::TOKEN_TRACK_END, {-1}, mask, 0); |
|
|
} |
|
|
else { |
|
|
rep->set_mask(midi::TOKEN_BAR, {-1}, mask, 0); |
|
|
} |
|
|
|
|
|
if (track_count >= num_tracks) { |
|
|
std::fill(mask.begin(), mask.end(), 0); |
|
|
finished = true; |
|
|
} |
|
|
|
|
|
|
|
|
if (!finished) { |
|
|
for (int i=0; i<(int)mask.size(); i++) { |
|
|
int num_bars = attribute_bar_masks[track_count].size(); |
|
|
int safe_bar_index = std::min(bar_count, num_bars - 1); |
|
|
mask[i] *= attribute_bar_masks[track_count][safe_bar_index][i]; |
|
|
} |
|
|
} |
|
|
} |
|
|
else if ((model_type == enums::BAR_INFILL_MODEL)) { |
|
|
|
|
|
if (infill_bar_count >= num_infill_bars) { |
|
|
std::fill(mask.begin(), mask.end(), 0); |
|
|
finished = true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if ((std::find(mask.begin(), mask.end(), 1) == mask.end()) && (!finished)) { |
|
|
throw std::runtime_error("FATAL ERROR : EVERY TOKEN IS MASKED"); |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<int> get_mask(std::vector<int> &tokens) { |
|
|
data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "get_mask" ); |
|
|
std::vector<int> mask(enc->rep->max_token(), 0); |
|
|
for (int t=token_position; t<(int)tokens.size(); t++) { |
|
|
if (verbose) { |
|
|
data_structures::LOGGER(data_structures::to_str("UPDATING [", token_position, "] :: ", enc->rep->pretty(tokens[t]))); |
|
|
} |
|
|
update( tokens[t] ); |
|
|
history.push_back( tokens[t] ); |
|
|
token_position++; |
|
|
} |
|
|
|
|
|
set_mask(tokens.back(), mask); |
|
|
return mask; |
|
|
} |
|
|
|
|
|
|
|
|
std::map<int,std::vector<int>> note_expiry; |
|
|
std::set<int> onsets; |
|
|
int last_token; |
|
|
int num_delta_tokens; |
|
|
midi::TRACK_TYPE current_track_type; |
|
|
|
|
|
int barlength; |
|
|
int timestep; |
|
|
int absolute_timestep; |
|
|
int bar_start_timestep; |
|
|
|
|
|
int bar_count; |
|
|
int track_count; |
|
|
int infill_bar_count; |
|
|
|
|
|
int num_bars; |
|
|
int num_tracks; |
|
|
int num_infill_bars; |
|
|
std::vector<std::vector<int>> attribute_masks; |
|
|
std::vector<std::vector<std::vector<int>>> attribute_bar_masks; |
|
|
|
|
|
std::vector<int> prompt; |
|
|
std::vector<int> inverse_order; |
|
|
std::string ckpt_path; |
|
|
|
|
|
int token_position; |
|
|
bool finished; |
|
|
enums::MODEL_TYPE model_type; |
|
|
std::vector<int> history; |
|
|
|
|
|
bool verbose; |
|
|
int polyphony_hard_limit; |
|
|
std::vector<int> polyphony_hard_limits; |
|
|
std::vector<float> track_temperatures; |
|
|
std::vector<std::pair<int,int>> selected_bars; |
|
|
|
|
|
std::unique_ptr<encoder::ENCODER> enc; |
|
|
std::shared_ptr<encoder::REPRESENTATION> rep; |
|
|
std::unique_ptr<REP_GRAPH> rg; |
|
|
std::unique_ptr<INSTRUMENT_CONDITIONAL_REP_GRAPH> instrument_rg; |
|
|
std::unique_ptr<DRUM_CONDITIONAL_REP_GRAPH> drum_rg; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
} |