Spaces:
Runtime error
| import torch | |
def widen_tensor(tensor, factor):
    """
    Repeat every row of ``tensor`` ``factor`` times along dimension 0.

    Copies of the same row are placed next to each other, so input row ``i``
    occupies output rows ``i * factor`` .. ``i * factor + factor - 1`` and the
    result has shape ``(tensor.shape[0] * factor, *tensor.shape[1:])``.
    A zero-dimensional tensor has no batch dimension and is returned
    unchanged (the very same object).
    """
    # Scalars have nothing to repeat along; hand them back untouched.
    if tensor.dim() != 0:
        return tensor.repeat_interleave(factor, dim=0)
    return tensor
def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expand the actor's data `sample_size` times to support beam sampling.

    Every non-scalar tensor attribute of ``actor.fleet`` and ``actor.graph``,
    plus ``actor.log_probs`` (and optionally ``actor.node_embeddings`` and the
    tensors in ``actor.node_projections``), is repeated ``actor.sample_size``
    times along its batch dimension. The actor is modified in place; nothing
    is returned.

    Args:
        actor: Object exposing ``sample_size``, ``fleet``, ``graph``,
            ``log_probs`` and, when requested, ``node_embeddings`` and
            ``node_projections`` (a dict of tensors).
        include_embeddings: Also widen ``actor.node_embeddings``.
        include_projections: Also widen every tensor in
            ``actor.node_projections``.
    """
    factor = actor.sample_size

    def widen_attributes(obj):
        # Iterate over a snapshot so reassigning attributes cannot disturb
        # iteration of the live __dict__.
        for name, value in list(obj.__dict__.items()):
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(obj, name, widen_tensor(value, factor))

    widen_attributes(actor.fleet)
    widen_attributes(actor.graph)
    actor.log_probs = widen_tensor(actor.log_probs, factor)
    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, factor)
    if include_projections:
        def widen_projection(x):
            if x.dim() > 3:
                # (heads, batch, graph, embed): the batch lives on dim 1.
                # repeat_interleave keeps each sample's copies adjacent (the
                # same ordering as widen_tensor) and works for any rank > 3,
                # not only exactly 4 dimensions.
                return x.repeat_interleave(factor, dim=1)
            return widen_tensor(x, factor)

        actor.node_projections = {
            key: widen_projection(value) for key, value in actor.node_projections.items()
        }
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Select a specific subset of data using `index` (e.g., for beam search pruning).

    Rows are gathered along each tensor's batch dimension and written back
    onto ``actor`` in place; nothing is returned. Tensor attributes of
    ``actor.fleet`` / ``actor.graph`` that are scalars, or that have too few
    rows to be indexed by ``index``, are left untouched.

    Args:
        actor: Object exposing ``fleet``, ``graph``, ``log_probs`` and, when
            requested, ``node_embeddings`` and ``node_projections`` (a dict of
            tensors).
        index: Tensor of row indices for the batch dimension. Presumably an
            integer (long) tensor; TODO confirm whether boolean masks are
            ever passed.
        include_embeddings: Also select ``actor.node_embeddings``.
        include_projections: Also select every tensor in
            ``actor.node_projections``.
    """
    # Computed once: .item() forces a device sync. An empty index has no
    # maximum (index.max() would raise), so -1 lets every non-scalar tensor
    # through; indexing with an empty index simply yields zero rows.
    max_index = index.max().item() if index.numel() > 0 else -1

    def select_attributes(obj):
        # Iterate over a snapshot so reassigning attributes cannot disturb
        # iteration of the live __dict__.
        for name, value in list(obj.__dict__.items()):
            # Valid row positions are 0 .. size(0) - 1, so the tensor must be
            # strictly larger than the biggest requested index.
            if (
                isinstance(value, torch.Tensor)
                and value.dim() > 0
                and value.size(0) > max_index
            ):
                setattr(obj, name, value[index])

    select_attributes(actor.fleet)
    select_attributes(actor.graph)
    actor.log_probs = actor.log_probs[index]
    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]
    if include_projections:
        def select_projection(x):
            # Projections with more than 3 dims are laid out as
            # (heads, batch, ...), so the batch is dim 1; others use dim 0.
            if x.dim() > 3:
                return x[:, index]
            return x[index]

        actor.node_projections = {
            key: select_projection(value) for key, value in actor.node_projections.items()
        }