| | from transformers import PretrainedConfig |
| | from transformers.utils import logging |
| | from transformers.models.esm import EsmConfig |
| |
|
# Module-level logger, created through the `transformers.utils.logging` helper and named after this module.
logger = logging.get_logger(__name__)
| |
|
| |
|
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Configuration options used to build the nested protein [`EsmConfig`] stored on `self.protein_config`.
            An [`EsmConfig`] instance is converted to a dictionary first, so the stored configuration is always a
            fresh object. When omitted, an [`EsmConfig`] with default values is created.
        kwargs (*optional*):
            Additional keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # `EsmConfig(**protein_config)` needs a mapping; accept a ready-made config object as well by
            # round-tripping it through `to_dict()`, exactly as `from_protein_text_configs` does.
            protein_config = protein_config.to_dict()

        self.protein_config = EsmConfig(**protein_config)

    @classmethod
    def from_protein_text_configs(cls, protein_config: EsmConfig, **kwargs) -> "ProtSTConfig":
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein model; it is serialized with `to_dict()` before being passed on.
            kwargs (*optional*):
                Additional keyword arguments, forwarded to the class constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object
        """
        return cls(protein_config=protein_config.to_dict(), **kwargs)