|
|
import gradio as gr |
|
|
import requests |
|
|
import json |
|
|
import os |
|
|
from google.oauth2 import service_account |
|
|
from google.auth.transport.requests import Request |
|
|
|
|
|
|
|
|
# Vertex AI deployment that serves the Atlas model.
PROJECT_ID: str = "white-dispatch-460720-h7"  # GCP project ID; used in the predict URL's resource path
ENDPOINT_ID: str = "5591670836659486720"  # endpoint ID; used in both the predict hostname and resource path
LOCATION: str = "us-central1"  # endpoint region; used in the hostname and the `locations/` path segment
|
|
|
|
|
def _get_access_token(service_account_key):
    """Exchange a service-account JSON key string for an OAuth2 access token.

    Args:
        service_account_key: The full contents of a service-account key file
            (JSON text), as stored in an environment variable.

    Returns:
        The bearer token string obtained by refreshing the credentials.

    Raises:
        ValueError: if the key is not valid JSON or is not a usable
            service-account key.
        Any google-auth error raised while refreshing the token (network call).
    """
    service_account_info = json.loads(service_account_key)
    credentials = service_account.Credentials.from_service_account_info(
        service_account_info,
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
    credentials.refresh(Request())
    return credentials.token


def _extract_content(result):
    """Pull the assistant message text out of a parsed predict response.

    Accepts ``result["predictions"]`` either as a single chat-completion object
    (a dict with ``choices``) or as a list of them, in which case the first one
    is used.

    Returns:
        The ``choices[0].message.content`` string, or ``None`` when the response
        does not have that shape.
    """
    if not isinstance(result, dict):
        return None
    predictions = result.get("predictions")
    if isinstance(predictions, list):
        # Indexing a list with a string key would raise TypeError; tolerate the
        # list form by taking the first prediction.
        predictions = predictions[0] if predictions else None
    if not isinstance(predictions, dict):
        return None
    choices = predictions.get("choices")
    if not isinstance(choices, list) or not choices:
        return None
    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        return None
    message = first_choice.get("message")
    if not isinstance(message, dict):
        return None
    content = message.get("content")
    return content if isinstance(content, str) else None


def query_vertex(user_input: str, *, max_tokens: int = 200, timeout: float = 120.0) -> str:
    """Send a product description to the Vertex AI endpoint and return the reply.

    The description is sent as a single user message in the
    ``chatCompletions`` request format. Every failure is returned as a
    human-readable string rather than raised, so the Gradio output textbox can
    display it.

    Args:
        user_input: Product description typed by the user.
        max_tokens: Upper bound on generated tokens, forwarded in the request.
        timeout: Seconds ``requests`` waits to connect / between response
            bytes before giving up, so a stuck endpoint cannot hang the UI.

    Returns:
        The model's reply text, or an error / diagnostic message.
    """
    if not user_input or not user_input.strip():
        return "Please enter a product description."

    service_account_key = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY")
    if not service_account_key:
        return "Error: GOOGLE_SERVICE_ACCOUNT_KEY environment variable not set. Please configure service account authentication in Hugging Face Spaces settings."

    try:
        token = _get_access_token(service_account_key)
    except Exception as e:  # UI boundary: surface any auth failure as text.
        return f"Error authenticating with service account: {e}"

    # Dedicated-endpoint hostname. NOTE(review): 104808504044 is hard-coded and
    # is presumably the project *number* of PROJECT_ID — update it if the
    # project changes.
    url = f"https://{ENDPOINT_ID}.{LOCATION}-104808504044.prediction.vertexai.goog/v1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/{ENDPOINT_ID}:predict"

    payload = {
        "instances": [
            {
                "@requestFormat": "chatCompletions",
                "messages": [
                    {"role": "user", "content": user_input},
                ],
                "max_tokens": max_tokens,
            }
        ]
    }
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Connection errors / timeouts would otherwise propagate into Gradio.
        return f"Error contacting Vertex AI endpoint: {e}"

    if not response.ok:
        return f"Error: HTTP {response.status_code} - {response.text}"

    try:
        result = response.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        return f"Error: {response.text} - {e}"

    content = _extract_content(result)
    if content is None:
        return f"Unexpected response format: {result}"
    return content
|
|
|
|
|
|
|
|
# Gradio UI: a multi-line product description goes in, and query_vertex's
# text reply (the classification or an error message) comes out.
_product_input = gr.Textbox(lines=4, placeholder="Enter your product description...")
_classification_output = gr.Textbox(
    lines=15,
    placeholder="HTS classification will appear here...",
)

demo = gr.Interface(
    query_vertex,
    inputs=_product_input,
    outputs=_classification_output,
    title="Atlas HTS Classification (via Vertex AI)",
    description="Enter a product description and Atlas (LLaMA-3.3-70B fine-tuned) will return the HTS classification.",
)

if __name__ == "__main__":
    # Only start the Gradio server when this file is run as a script.
    demo.launch()