File size: 3,245 Bytes
afd99d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import inspect
import uuid
from abc import ABC, abstractmethod
from functools import wraps

import gradio as gr
from gradio_client import Client


__version__ = "0.1.0"

# Names exported by ``from <package> import *``.
__all__ = [
    "Environment",
    "load",
    "register_env",
]


class Environment(ABC):
    """Abstract base class for environments served or loaded by this module.

    Concrete subclasses must override both ``reset`` and ``step``; the class
    cannot be instantiated until they do. Argument lists and return values
    are chosen by each subclass.
    """

    @abstractmethod
    def reset(self, *args, **kwargs):
        """Abstract ``reset`` hook; subclasses define its arguments and result."""

    @abstractmethod
    def step(self, *args, **kwargs):
        """Abstract ``step`` hook; subclasses define its arguments and result."""


class _RemoteEnvironment(Environment):
    """Client-side proxy for an environment served by a remote Gradio app.

    Construction opens a new server-side session through the app's ``/init``
    endpoint; ``reset`` and ``step`` then forward their arguments, prefixed
    with that session ID, to the ``/reset`` and ``/step`` endpoints.
    """

    def __init__(self, env_id: str):
        """Connect to ``env_id`` and start a new session on it.

        Args:
            env_id: A Hugging Face Space ID of the form ``"username/repo"``,
                or a full ``http://`` / ``https://`` URL of a Gradio app.

        Raises:
            ValueError: If ``env_id`` is neither a URL nor ``"username/repo"``.
        """
        if not env_id.startswith(("http://", "https://")):
            username, sep, repo = env_id.partition("/")
            if not (username and sep and repo) or "/" in repo:
                raise ValueError(
                    "env_id must be 'username/repo' or an http(s) URL, "
                    f"got {env_id!r}"
                )
        # Let gradio_client resolve the Space itself. Hand-building
        # ``https://{username}-{repo}.hf.space`` breaks for names containing
        # uppercase letters, underscores or dots, which Hugging Face
        # normalizes in the Space subdomain.
        self.client = Client(env_id)
        self.session_id = self.client.predict(api_name="/init")

    def _call(self, api_name: str, *args, **kwargs):
        """Invoke ``api_name`` on the remote app with this session's ID first."""
        return self.client.predict(
            self.session_id, *args, api_name=api_name, **kwargs
        )

    def reset(self, *args, **kwargs):
        """Forward a ``reset`` call to the remote ``/reset`` endpoint."""
        return self._call("/reset", *args, **kwargs)

    def step(self, *args, **kwargs):
        """Forward a ``step`` call to the remote ``/step`` endpoint."""
        return self._call("/step", *args, **kwargs)


def load(env_id: str) -> _RemoteEnvironment:
    """Connect to a hosted environment and open a new session on it.

    Args:
        env_id: Identifier of the hosted environment, e.g. ``"username/repo"``.

    Returns:
        A ``_RemoteEnvironment`` whose ``reset``/``step`` calls are forwarded
        to the remote app under its own session ID.
    """
    remote_env = _RemoteEnvironment(env_id)
    return remote_env


def bind_method_to_session(method, registry: dict):
    """Wrap an environment method so that calls are dispatched by session ID.

    The returned wrapper takes a ``session_id`` followed by the method's own
    arguments, looks up the instance stored under that ID in ``registry`` and
    calls the same-named method on it. The wrapper's ``__signature__`` and
    ``__annotations__`` advertise ``session_id: str`` followed by the
    parameters of ``method``, so signature-introspecting callers (such as
    ``gr.api``) see the full parameter list.

    Args:
        method: A bound method (e.g. ``instance.step``). Only its name,
            signature, annotations and metadata are used; calls are routed
            to the instance found in ``registry``, not to ``method``'s own
            instance.
        registry: Mapping of session ID to environment instance.

    Returns:
        The session-dispatching wrapper function.

    Raises:
        ValueError: While building the wrapper, if ``method`` already has a
            parameter named ``session_id``. When the wrapper is called, if
            ``session_id`` is not present in ``registry``.
    """
    sig = inspect.signature(method)
    params = list(sig.parameters.values())
    # Resolve the attribute name once instead of on every call; bound
    # methods expose ``__name__`` directly.
    method_name = method.__name__

    if any(p.name == "session_id" for p in params):
        raise ValueError(
            f"{method_name!r} already has a parameter named 'session_id'"
        )

    @wraps(method)
    def wrapper(session_id: str, *args, **kwargs):
        instance = registry.get(session_id)
        if instance is None:
            raise ValueError(f"Invalid session_id: {session_id}")
        return getattr(instance, method_name)(*args, **kwargs)

    # ``wraps`` assigns the method's annotations dict by reference; build a
    # new dict so adding ``session_id`` does not mutate the original method.
    wrapper.__annotations__ = {**method.__annotations__, "session_id": str}

    # ``inspect.Signature`` requires positional-only parameters to come
    # first, so a POSITIONAL_OR_KEYWORD ``session_id`` cannot precede a
    # positional-only one. Match that kind in that case.
    if params and params[0].kind == inspect.Parameter.POSITIONAL_ONLY:
        session_kind = inspect.Parameter.POSITIONAL_ONLY
    else:
        session_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD

    session_param = inspect.Parameter(
        "session_id",
        kind=session_kind,
        annotation=str,
    )
    wrapper.__signature__ = inspect.Signature(
        parameters=[session_param, *params],
        return_annotation=sig.return_annotation,
    )

    return wrapper


def register_env(env_cls: type[Environment]) -> None:
    """
    Register an environment class with Gradio APIs.

    Intended to be called inside a ``gr.Blocks`` context (see the example).
    It adds three API endpoints:

    - ``init``: constructs a new ``env_cls()`` instance, stores it under a
      fresh UUID4 session ID and returns that ID.
    - ``reset``: takes a session ID followed by the arguments of
      ``env_cls.reset`` and calls ``reset`` on that session's instance.
    - ``step``: same as ``reset``, but for ``env_cls.step``.

    Args:
        env_cls: The ``Environment`` subclass to serve. It is instantiated
            with no arguments: once here, to read the ``reset``/``step``
            signatures, and once per ``init`` call.

    Example:

    ```python
    from environments import register_env, Environment
    import gradio as gr

    class MyEnvironmentClass(Environment):
        def reset(self) -> str:
            return "Reset called!"

        def step(self, action: str) -> str:
            return f"Step called with action: {action}!"

    with gr.Blocks() as demo:
        register_env(MyEnvironmentClass)

    demo.launch(mcp_server=True)
    ```
    """
    # Live instances keyed by session ID. Nothing here ever removes an entry,
    # so the dict grows by one instance per ``init`` call for the app's life.
    sessions: dict[str, Environment] = {}

    def init_env() -> str:
        """
        Initialize a new environment instance and return a session ID.
        Returns:
            A unique session ID for the new environment instance.
        """
        session_id = str(uuid.uuid4())
        env = env_cls()
        sessions[session_id] = env
        return session_id

    # One throwaway instance supplies the bound ``reset``/``step`` methods
    # whose signatures (minus ``self``) define the API parameters. It is not
    # stored in ``sessions``; requests are served by per-session instances.
    prototype = env_cls()
    reset_api = bind_method_to_session(prototype.reset, sessions)
    step_api = bind_method_to_session(prototype.step, sessions)

    # Create Gradio APIs
    gr.api(
        init_env,
        api_name="init",
        api_description="Initialize a new environment session",
    )
    gr.api(reset_api, api_name="reset")
    gr.api(step_api, api_name="step")