arahrooh committed on
Commit
6edb5d8
·
1 Parent(s): 3ca9195

Fix: Add fallback for InferenceClient API compatibility

Browse files
Files changed (1) hide show
  1. app.py +52 -20
app.py CHANGED
@@ -193,20 +193,36 @@ class InferenceAPIBot:
193
  def generate_answer(self, prompt: str, **kwargs) -> str:
194
  """Generate answer using Inference API"""
195
  try:
196
- # Convert prompt to chat format
197
- messages = [{"role": "user", "content": prompt}]
 
 
 
198
 
199
- # Call Inference API
200
- completion = self.client.chat.completions.create(
201
- model=self.current_model,
202
- messages=messages,
203
- max_tokens=kwargs.get('max_new_tokens', 512),
204
- temperature=kwargs.get('temperature', 0.2),
205
- top_p=kwargs.get('top_p', 0.9),
206
- )
207
-
208
- answer = completion.choices[0].message.content
209
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  except Exception as e:
211
  logger.error(f"Error calling Inference API: {e}", exc_info=True)
212
  return f"Error generating answer: {str(e)}"
@@ -264,14 +280,30 @@ class InferenceAPIBot:
264
  ]
265
 
266
  # Call Inference API
267
- completion = self.client.chat.completions.create(
268
- model=self.current_model,
269
- messages=messages,
270
- max_tokens=512 if target_level in ["college", "doctoral"] else 384,
271
- temperature=0.4 if target_level in ["college", "doctoral"] else 0.3,
272
- )
273
 
274
- enhanced_answer = completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  # Clean the answer (same as bot.py)
276
  cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)
277
 
 
193
  def generate_answer(self, prompt: str, **kwargs) -> str:
194
  """Generate answer using Inference API"""
195
  try:
196
+ # Use text generation API (more reliable than chat.completions)
197
+ # The InferenceClient supports both formats, but text_generation is more stable
198
+ max_tokens = kwargs.get('max_new_tokens', 512)
199
+ temperature = kwargs.get('temperature', 0.2)
200
+ top_p = kwargs.get('top_p', 0.9)
201
 
202
+ # Try chat.completions first (newer API)
203
+ try:
204
+ messages = [{"role": "user", "content": prompt}]
205
+ completion = self.client.chat.completions.create(
206
+ model=self.current_model,
207
+ messages=messages,
208
+ max_tokens=max_tokens,
209
+ temperature=temperature,
210
+ top_p=top_p,
211
+ )
212
+ answer = completion.choices[0].message.content
213
+ return answer
214
+ except (AttributeError, TypeError) as e:
215
+ # Fallback to text generation API if chat.completions not available
216
+ logger.warning(f"chat.completions not available, using text_generation: {e}")
217
+ response = self.client.text_generation(
218
+ prompt,
219
+ model=self.current_model,
220
+ max_new_tokens=max_tokens,
221
+ temperature=temperature,
222
+ top_p=top_p,
223
+ return_full_text=False,
224
+ )
225
+ return response
226
  except Exception as e:
227
  logger.error(f"Error calling Inference API: {e}", exc_info=True)
228
  return f"Error generating answer: {str(e)}"
 
280
  ]
281
 
282
  # Call Inference API
283
+ max_tokens = 512 if target_level in ["college", "doctoral"] else 384
284
+ temperature = 0.4 if target_level in ["college", "doctoral"] else 0.3
 
 
 
 
285
 
286
+ try:
287
+ # Try chat.completions first
288
+ completion = self.client.chat.completions.create(
289
+ model=self.current_model,
290
+ messages=messages,
291
+ max_tokens=max_tokens,
292
+ temperature=temperature,
293
+ )
294
+ enhanced_answer = completion.choices[0].message.content
295
+ except (AttributeError, TypeError) as e:
296
+ # Fallback to text generation
297
+ logger.warning(f"chat.completions not available for readability, using text_generation: {e}")
298
+ # Combine system and user messages for text generation
299
+ combined_prompt = f"{system_message}\n\n{user_message}"
300
+ enhanced_answer = self.client.text_generation(
301
+ combined_prompt,
302
+ model=self.current_model,
303
+ max_new_tokens=max_tokens,
304
+ temperature=temperature,
305
+ return_full_text=False,
306
+ )
307
  # Clean the answer (same as bot.py)
308
  cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)
309