Spanicin committed
Commit 3c746b0 · verified · 1 Parent(s): cf32278

Upload 2 files

Files changed (2)
  1. Dockerfile +27 -0
  2. app.py +108 -0
Dockerfile ADDED
@@ -0,0 +1,27 @@
+ FROM python:3.9.13
+ USER root
+
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends \
+     libgl1-mesa-glx \
+     git \
+     && \
+     rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -m -u 1000 user
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ ENV NUMBA_CACHE_DIR=/tmp/numba_cache
+ RUN git clone https://github.com/openai/shap-e
+ RUN pip install -r requirements.txt
+
+
+
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "app:app"]
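
Note: the CMD above starts gunicorn with its defaults. Because app.py (below) keeps the generation thread and its result in per-process state via app.config, the service only behaves predictably with a single worker process. A minimal gunicorn.conf.py sketch that pins this down (a hypothetical file, not part of this commit; gunicorn config files are plain Python) could look like:

    # gunicorn.conf.py -- hypothetical; not part of this commit.
    # app.py stores the generation thread and its result in app.config,
    # which is per-process state, so pin gunicorn to a single worker:
    # with more workers, /status could land on a different process than /run.
    workers = 1
    bind = "0.0.0.0:7860"

gunicorn picks up a gunicorn.conf.py from the working directory automatically, so with this file in place the CMD could shrink to gunicorn app:app.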
app.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ from shap_e.diffusion.sample import sample_latents
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
+ from shap_e.models.download import load_model, load_config
+ from shap_e.util.notebooks import create_pan_cameras, decode_latent_images
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ import threading
+ import io
+ import base64
+
+ app = Flask(__name__)
+ CORS(app)
+
+ # Cached (device, xm, model, diffusion) tuple, loaded on first use.
+ pipe = None
+ app.config['temp_response'] = None
+ app.config['generation_thread'] = None
+
+
+ def initialize_model():
+     print('Downloading the model weights')
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     xm = load_model('transmitter', device=device)
+     model = load_model('text300M', device=device)
+     diffusion = diffusion_from_config(load_config('diffusion'))
+     return device, xm, model, diffusion
+
+
+ def generate_image_gif(prompt):
+     global pipe
+     try:
+         # Load the models once and reuse them across requests.
+         if pipe is None:
+             pipe = initialize_model()
+         device, xm, model, diffusion = pipe
+
+         batch_size = 1
+         guidance_scale = 30.0
+
+         latents = sample_latents(
+             batch_size=batch_size,
+             model=model,
+             diffusion=diffusion,
+             guidance_scale=guidance_scale,
+             model_kwargs=dict(texts=[prompt] * batch_size),
+             progress=True,
+             clip_denoised=True,
+             use_fp16=True,
+             use_karras=True,
+             karras_steps=64,
+             sigma_min=1e-3,
+             sigma_max=160,
+             s_churn=0,
+         )
+
+         render_mode = 'nerf'  # you can change this to 'stf'
+         size = 256  # size of the renders; higher values take longer to render
+
+         # Render the latent from a ring of pan cameras and pack the frames into a GIF.
+         cameras = create_pan_cameras(size, device)
+         images = decode_latent_images(xm, latents, cameras, rendering_mode=render_mode)
+         writer = io.BytesIO()
+         images[0].save(writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0)
+         writer.seek(0)
+         data = base64.b64encode(writer.read()).decode("ascii")
+         return {'video_base64': data, 'status': None}
+     except Exception as e:
+         print(f"Error generating 3D: {e}")
+         return {"error": f"Failed to generate 3D animation: {e}", "status": "failed"}
+
+
+ def background(prompt):
+     with app.app_context():
+         data = generate_image_gif(prompt)
+         app.config['temp_response'] = data
+
+
+ @app.route('/run', methods=['POST'])
+ def handle_animation_request():
+     prompt = request.form.get('prompt')
+     if prompt:
+         # Clear any result left over from a previous run before starting a new thread.
+         app.config['temp_response'] = None
+         generation_thread = threading.Thread(target=background, args=(prompt,))
+         app.config['generation_thread'] = generation_thread
+         generation_thread.start()
+         return jsonify({"message": "3D generation started", "process_id": generation_thread.ident})
+     else:
+         return jsonify({"message": "Please provide a valid text prompt."}), 400
+
+
+ @app.route('/status', methods=['GET'])
+ def check_animation_status():
+     process_id = request.args.get('process_id', None)
+     if not process_id:
+         return jsonify({"message": "Please provide a valid process_id."}), 400
+
+     generation_thread = app.config.get('generation_thread')
+     if generation_thread and generation_thread.is_alive():
+         return jsonify({"status": "in_progress"}), 200
+     if app.config.get('temp_response'):
+         final_response = app.config['temp_response']
+         if final_response.get('status') != 'failed':
+             final_response['status'] = 'completed'
+         return jsonify(final_response)
+     return jsonify({"status": "unknown"}), 404
+
+
+ if __name__ == '__main__':
+     app.run(debug=True)
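
For reference, a minimal client for the two routes above (a sketch, not part of this commit; it assumes the container is reachable at http://localhost:7860 and that the requests package is installed, and the prompt "a shark" is just a placeholder):

    # client.py -- hypothetical usage sketch, not part of this commit.
    import base64
    import time

    import requests

    BASE = "http://localhost:7860"

    # /run expects a form field named "prompt" and returns a process_id.
    run = requests.post(f"{BASE}/run", data={"prompt": "a shark"}).json()
    process_id = run["process_id"]

    # Poll /status until the background thread reports a terminal state.
    while True:
        status = requests.get(f"{BASE}/status", params={"process_id": process_id}).json()
        if status.get("status") in ("completed", "failed"):
            break
        time.sleep(5)

    # The animation comes back base64-encoded under "video_base64".
    if status["status"] == "completed":
        with open("output.gif", "wb") as f:
            f.write(base64.b64decode(status["video_base64"]))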