Fibo-local / optimization.py
linoyts's picture
linoyts HF Staff
Update optimization.py
c654722 verified
raw
history blame
955 Bytes
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
P = ParamSpec('P')
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Export and compile ``pipeline.transformer``, then apply the result.

    The work happens in three steps:

    1. ``pipeline`` is called once with ``*args`` / ``**kwargs`` inside
       ``spaces.aoti_capture(pipeline.transformer)``.
    2. ``pipeline.transformer`` is exported with ``torch.export.export``,
       using the ``args`` and ``kwargs`` exposed by the capture object as
       the export inputs, and the exported program is compiled through
       ``spaces.aoti_compile`` with ``INDUCTOR_CONFIGS``.
    3. The compiled result is passed to ``spaces.aoti_apply`` together with
       ``pipeline.transformer``.

    Args:
        pipeline: Callable that also exposes a ``transformer`` attribute.
        *args: Positional arguments forwarded to the ``pipeline`` call.
        **kwargs: Keyword arguments forwarded to the ``pipeline`` call.

    Returns:
        None. The function is used for its effect on ``pipeline.transformer``.
    """

    # NOTE(review): ``duration`` is presumably the GPU time budget in seconds
    # for this call -- confirm against the ``spaces.GPU`` documentation.
    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Step 1: run the pipeline once while the capture context is active.
        with spaces.aoti_capture(pipeline.transformer) as captured:
            pipeline(*args, **kwargs)

        # Step 2: export with the captured inputs, then compile.
        module = pipeline.transformer
        exported_program = torch.export.export(
            mod=module,
            args=captured.args,
            kwargs=captured.kwargs,
        )
        compiled = spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)
        return compiled

    # QKV projection fusion is currently disabled; kept for reference:
    # pipeline.transformer.fuse_qkv_projections()

    # Step 3: compile first, then look up the transformer to apply it to.
    compiled_transformer = compile_transformer()
    spaces.aoti_apply(compiled_transformer, pipeline.transformer)