File size: 4,010 Bytes
3c746b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebb14dd
 
 
 
 
 
 
 
 
 
 
 
3c746b0
 
261aa57
 
 
 
 
 
ebb14dd
261aa57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4613216
261aa57
 
ebb14dd
261aa57
 
3c746b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from flask import Flask, request, jsonify
from flask_cors import CORS
import threading
import io
import base64

# Flask application; CORS is enabled so a browser frontend on another
# origin can call the endpoints below.
app = Flask(__name__)
CORS(app)

pipe = None  # NOTE(review): never assigned elsewhere in this view — presumably a leftover; confirm before removing

# Cross-request state shared between the /run worker thread and /status:
# the worker's finished payload and a handle to the worker thread itself.
app.config.update(temp_response=None, generation_thread=None)


# def initialize_model():
#     global pipe
#     try:
#       print('Downloading the model weights')
#       device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#       xm = load_model('transmitter', device=device)
#       model = load_model('text300M', device=device)
#       diffusion = diffusion_from_config(load_config('diffusion'))
#       return device, xm, model, diffusion
#     except Exception as e:
#       print(f"Error downloading the model: {e}")
#       return jsonify({"error": f"Failed to download model: {str(e)}"}), 500

def generate_image_gif(prompt):
    """Generate a rotating 3D render for *prompt* and return it as a GIF.

    Loads the Shap-E 'transmitter' and 'text300M' models plus the diffusion
    config on every call (NOTE(review): this is expensive — consider loading
    once at module level), samples a latent conditioned on the text prompt,
    renders a pan of camera views, and encodes the frames as a looping GIF.

    Returns a plain dict in BOTH outcomes, so the background thread can store
    it in app.config and the /status handler can mutate its 'status' key:
      success: {'base64_3d': <ascii base64 of the GIF>, 'status': None}
      failure: {'error': <message>, 'status': 'failed'}
    """
    try:
        print('Downloading the model weights')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Model loading happens inside the try so that download/setup
        # failures surface as an error dict instead of killing the worker
        # thread and leaving /status with nothing to report.
        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        diffusion = diffusion_from_config(load_config('diffusion'))

        batch_size = 1
        guidance_scale = 30.0

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1E-3,
            sigma_max=160,
            s_churn=0,
        )

        render_mode = 'nerf'  # alternative: 'stf'
        size = 256  # render resolution; higher values take longer to render

        cameras = create_pan_cameras(size, device)
        images = decode_latent_images(xm, latents, cameras, rendering_mode=render_mode)

        # Assemble the rendered frames into an in-memory animated GIF and
        # base64-encode it for transport in a JSON response.
        writer = io.BytesIO()
        images[0].save(writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0)
        writer.seek(0)
        data = base64.b64encode(writer.read()).decode("ascii")
        response_data = {'base64_3d': data, 'status': None}
        print('response_data', response_data)
        return response_data
    except Exception as e:
        # BUG FIX: the original returned `jsonify(...), 500` here, but this
        # function runs in a background thread and its result is stored in
        # app.config['temp_response']; /status then executes
        # final_response['status'] = 'completed', which raises on a
        # (Response, int) tuple. Return a plain dict like the success path.
        print(f"Error generating 3D: {e}")
        return {"error": f"Failed to generate 3D animation: {str(e)}", "status": "failed"}

def background(prompt):
    """Worker entry point: render *prompt* and publish the result.

    Runs inside an application context so Flask helpers invoked by
    generate_image_gif work off the request thread; the finished payload
    is handed to /status via app.config['temp_response'].
    """
    with app.app_context():
        app.config['temp_response'] = generate_image_gif(prompt)

@app.route('/run', methods=['POST'])
def handle_animation_request():
    """Start 3D generation for the posted 'prompt' form field.

    Spawns a background worker thread and returns its ident as a pseudo
    process id for the client to poll via /status; 400 when no prompt
    was supplied.
    """
    prompt = request.form.get('prompt')
    if not prompt:
        return jsonify({"message": "Please provide a valid text prompt."}), 400

    worker = threading.Thread(target=background, args=(prompt,))
    app.config['generation_thread'] = worker
    worker.start()

    return jsonify({"message": "3D generation started", "process_id": worker.ident})

@app.route('/status', methods=['GET'])
def check_animation_status():
    """Report the state of the background generation started by /run.

    Query params:
        process_id: the ident returned by /run (presence-checked only; the
            single shared worker in app.config is what is actually inspected).

    Returns:
        200 {"status": "in_progress"} while the worker thread is alive;
        200 with the stored payload (status set to 'completed') when done;
        404 when no result is available; 400 when process_id is missing.
    """
    process_id = request.args.get('process_id', None)

    # BUG FIX: the original returned None (→ Flask 500) both when
    # process_id was absent and when no result had been stored yet.
    if not process_id:
        return jsonify({"message": "Please provide a valid process_id."}), 400

    generation_thread = app.config.get('generation_thread')
    if generation_thread and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    final_response = app.config.get('temp_response')
    if final_response:
        final_response['status'] = 'completed'
        return jsonify(final_response)

    # Thread finished (or never started) without publishing a result.
    return jsonify({"status": "not_found"}), 404

# Run the Flask development server when executed directly.
# debug=True enables the reloader/debugger — not suitable for production.
if __name__ == '__main__':
    app.run(debug=True)