Spaces:
Sleeping
Sleeping
File size: 3,245 Bytes
afd99d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import inspect
import uuid
from abc import ABC, abstractmethod
from functools import wraps
import gradio as gr
from gradio_client import Client
# Version string of this module.
__version__ = "0.1.0"
# Names exported by `from <module> import *`.
__all__ = ["Environment", "load", "register_env"]
class Environment(ABC):
    """
    Abstract interface that every environment implements.

    Concrete subclasses must override both ``reset`` and ``step``; the
    arguments and return values of each are chosen by the subclass.
    ``register_env`` exposes these two methods as Gradio API endpoints, and
    ``load`` returns a remote proxy implementing the same pair.
    """

    @abstractmethod
    def reset(self, *args, **kwargs):
        """Reset the environment. Arguments and return value are subclass-defined."""

    @abstractmethod
    def step(self, *args, **kwargs):
        """Perform one environment step. Arguments and return value are subclass-defined."""
class _RemoteEnvironment(Environment):
    """
    Client-side proxy for an environment served by a Gradio app on a
    Hugging Face Space (for example one built with ``register_env``).

    Construction connects a ``gradio_client.Client`` to the Space and calls
    its ``/init`` endpoint once. The returned value is stored as
    ``session_id`` and sent as the first argument of every ``/reset`` and
    ``/step`` call, so all calls on this object target the same remote
    environment instance.
    """

    def __init__(self, env_id: str):
        """
        Connect to the Space and open a new environment session.

        Args:
            env_id: Space id of the form ``"<owner>/<space-name>"``.

        Raises:
            ValueError: If ``env_id`` is not of the form
                ``"<owner>/<space-name>"``.
        """
        owner, sep, space = env_id.partition("/")
        if not (owner and sep and space) or "/" in space:
            raise ValueError(
                f"env_id must look like '<owner>/<space-name>', got {env_id!r}"
            )
        # Let gradio_client resolve the Space id to its host. The *.hf.space
        # subdomain is a normalized form of the id (e.g. lower-cased, with
        # characters such as '_' mapped to '-'), so gluing the raw owner and
        # name together breaks for many valid Space names.
        self.client = Client(env_id)
        self.session_id = self.client.predict(api_name="/init")

    def reset(self, *args, **kwargs):
        """Call the Space's ``/reset`` endpoint for this session and return its result."""
        return self.client.predict(
            self.session_id, *args, api_name="/reset", **kwargs
        )

    def step(self, *args, **kwargs):
        """Call the Space's ``/step`` endpoint for this session and return its result."""
        return self.client.predict(
            self.session_id, *args, api_name="/step", **kwargs
        )
def load(env_id: str) -> _RemoteEnvironment:
    """
    Connect to an environment hosted on a Hugging Face Space.

    Args:
        env_id: Space id of the form ``"<owner>/<space-name>"``.

    Returns:
        A ``_RemoteEnvironment`` proxy. Constructing it opens a new session
        on the remote app.
    """
    remote_env = _RemoteEnvironment(env_id)
    return remote_env
def bind_method_to_session(method, registry: dict):
    """
    Wrap an environment method so each call is dispatched to a per-session instance.

    The returned callable takes a ``session_id`` as its first argument. It
    looks up the instance stored under that id in ``registry`` and calls the
    attribute named ``method.__name__`` on it with the remaining arguments.

    The wrapper's ``__signature__`` and ``__annotations__`` are those of
    ``method`` with ``session_id: str`` prepended. Tools that introspect the
    callable (such as Gradio's API builder) therefore see the session id as
    an extra input.

    Args:
        method: Bound method (or plain function, e.g. a staticmethod looked
            up on an instance) that supplies the name, signature and
            annotations. It is never called itself.
        registry: Mapping from session id to instance. It is read on every
            call, so sessions added after wrapping are visible.

    Returns:
        The session-dispatching wrapper.

    Raises:
        ValueError: At wrap time, if ``method`` already has a parameter named
            ``session_id``. At call time, if the given session id is not in
            ``registry``.
    """
    sig = inspect.signature(method)
    params = list(sig.parameters.values())
    if "session_id" in sig.parameters:
        raise ValueError(
            f"{method.__name__} already has a 'session_id' parameter; "
            "it cannot be bound to a session"
        )
    # Bound methods forward __name__ to their underlying function, so this
    # works for bound methods and plain functions alike. The previous
    # method.__func__.__name__ failed for plain functions.
    method_name = method.__name__

    @wraps(method)
    def wrapper(session_id: str, *args, **kwargs):
        instance = registry.get(session_id)
        if instance is None:
            raise ValueError(f"Invalid session_id: {session_id}")
        return getattr(instance, method_name)(*args, **kwargs)

    # --- update __annotations__ ---
    # wraps() made wrapper share method's annotations dict. Copy it before
    # adding session_id so the wrapped method's own annotations stay intact.
    wrapper.__annotations__ = method.__annotations__.copy()
    wrapper.__annotations__["session_id"] = str

    # --- build signature ---
    # A valid signature cannot place a positional-or-keyword parameter before
    # a positional-only one, so session_id adopts that kind when needed.
    if params and params[0].kind is inspect.Parameter.POSITIONAL_ONLY:
        session_kind = inspect.Parameter.POSITIONAL_ONLY
    else:
        session_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
    new_params = (
        inspect.Parameter("session_id", kind=session_kind, annotation=str),
        *params,
    )
    wrapper.__signature__ = inspect.Signature(
        parameters=new_params,
        return_annotation=sig.return_annotation,
    )
    return wrapper
def register_env(env_cls: type[Environment]) -> None:
    """
    Register an environment class with Gradio APIs.

    Intended to be called inside a ``gr.Blocks()`` context (see the example).
    It registers three API endpoints:

    - ``init``: creates a new ``env_cls()`` instance and returns its session ID.
    - ``reset`` / ``step``: take a session ID followed by the arguments of
      ``env_cls.reset`` / ``env_cls.step``, and call that method on the
      session's instance.

    Args:
        env_cls: Concrete ``Environment`` subclass. It must be constructible
            with no arguments.

    Example:
    ```python
    from environments import register_env, Environment
    import gradio as gr

    class MyEnvironmentClass(Environment):
        def reset(self) -> str:
            return "Reset called!"

        def step(self, action: str) -> str:
            return f"Step called with action: {action}!"

    with gr.Blocks() as demo:
        register_env(MyEnvironmentClass)

    demo.launch(mcp_server=True)
    ```
    """
    # Session ID -> environment instance.
    # NOTE(review): entries are never removed, so sessions accumulate for the
    # lifetime of the app. Consider a close/expiry endpoint if that matters.
    sessions = {}

    def init_env() -> str:
        """
        Initialize a new environment instance and return a session ID.
        Returns:
            A unique session ID for the new environment instance.
        """
        session_id = str(uuid.uuid4())
        sessions[session_id] = env_cls()
        return session_id

    # A single probe instance supplies the bound reset/step methods whose
    # signatures define the /reset and /step API inputs. Previously a separate
    # throwaway instance was constructed for each method. The probe is not
    # stored in `sessions`, so no client can ever address it.
    probe = env_cls()
    reset_api = bind_method_to_session(probe.reset, sessions)
    step_api = bind_method_to_session(probe.step, sessions)

    # Create Gradio APIs
    gr.api(
        init_env,
        api_name="init",
        api_description="Initialize a new environment session",
    )
    gr.api(reset_api, api_name="reset")
    gr.api(step_api, api_name="step")
|