Spaces:
Runtime error
Runtime error
| import torch | |
def widen_tensor(tensor, factor):
    """
    Repeat every entry along dimension 0 `factor` times, back to back.

    A tensor of shape (B, ...) becomes (B * factor, ...), ordered as
    [x0, x0, ..., x1, x1, ...]. A zero-dimensional tensor has no leading
    dimension to widen and is returned as-is (the same object).
    """
    if tensor.ndim == 0:
        return tensor
    batch, *rest = tensor.shape
    # Insert a length-1 axis right after the leading axis, tile that axis
    # `factor` times, then fold it back into the leading axis.
    tiled = tensor.unsqueeze(1).repeat(1, factor, *([1] * len(rest)))
    return tiled.reshape(batch * factor, *rest)
def _widen_attributes(container, factor):
    """Replace every non-scalar tensor attribute of `container` with its widened copy."""
    # Snapshot the items so that reassigning attributes cannot disturb iteration.
    for name, value in list(vars(container).items()):
        if isinstance(value, torch.Tensor) and value.ndim > 0:
            setattr(container, name, widen_tensor(value, factor))


def _widen_projection(tensor, factor):
    """Widen one projection tensor along the dimension treated as its batch axis."""
    if tensor.ndim > 3:
        # Tensors of rank > 3 are treated as (n_heads, B, ...) and become
        # (n_heads, B * factor, ...): insert an axis after B, tile it, fold it
        # back into B. Works for any number of trailing dimensions.
        tiling = [1, 1, factor] + [1] * (tensor.ndim - 2)
        tiled = tensor.unsqueeze(2).repeat(*tiling)
        return tiled.reshape(
            tensor.shape[0], tensor.shape[1] * factor, *tensor.shape[2:]
        )
    return widen_tensor(tensor, factor)


def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Widen the actor's state for beam search by repeating each batch entry
    `actor.sample_size` times.

    The actor is modified in place; nothing is returned.

    - Every tensor attribute of `actor.fleet` and `actor.graph` with at least
      one dimension is widened along dimension 0. Zero-dimensional tensors and
      non-tensor attributes are left untouched.
    - `actor.log_probs` is always widened.
    - `actor.node_embeddings` is widened when `include_embeddings` is true.
    - Each value of the `actor.node_projections` dict is widened when
      `include_projections` is true. Values of rank > 3 are widened along
      dimension 1 (they are treated as having a leading heads dimension);
      all others along dimension 0.
    """
    sample_size = actor.sample_size

    # Fleet and graph tensors share the same treatment.
    _widen_attributes(actor.fleet, sample_size)
    _widen_attributes(actor.graph, sample_size)

    actor.log_probs = widen_tensor(actor.log_probs, sample_size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, sample_size)

    if include_projections:
        actor.node_projections = {
            key: _widen_projection(tensor, sample_size)
            for key, tensor in actor.node_projections.items()
        }
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Keep only the batch entries named by `index`, usually used to retain the
    top-k paths in beam search.

    The actor is modified in place; nothing is returned.

    - `index` is converted to int64 and used for tensor indexing.
    - A tensor attribute of `actor.fleet` / `actor.graph` is indexed along
      dimension 0 only if it has at least one dimension and its leading size
      exceeds the largest requested index. Anything else (zero-dimensional
      tensors, shorter tensors, non-tensor attributes) is left untouched.
    - `actor.log_probs` is always indexed.
    - `actor.node_embeddings` is indexed when `include_embeddings` is true.
    - Each value of the `actor.node_projections` dict is indexed when
      `include_projections` is true: along dimension 1 for rank > 3, along
      dimension 0 otherwise.
    """
    index = index.long()
    # max() is undefined for an empty index; -1 lets every non-scalar tensor
    # qualify, so an empty index simply yields empty selections.
    max_index = index.max().item() if index.numel() > 0 else -1

    def is_selectable(value):
        # The ndim check must come first: a 0-dim tensor has no shape[0].
        return (
            isinstance(value, torch.Tensor)
            and value.ndim > 0
            and value.shape[0] > max_index
        )

    # Fleet and graph tensors share the same treatment.
    for container in (actor.fleet, actor.graph):
        # Snapshot the items so that reassigning attributes cannot disturb iteration.
        for name, value in list(vars(container).items()):
            if is_selectable(value):
                setattr(container, name, value[index])

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        # Rank > 3 projections are indexed on dimension 1 (dimension 0 is
        # treated as a heads axis); the rest on dimension 0.
        actor.node_projections = {
            key: tensor[:, index] if tensor.ndim > 3 else tensor[index]
            for key, tensor in actor.node_projections.items()
        }