fix: rag
This commit is contained in:
		
							parent
							
								
									bcc27e3852
								
							
						
					
					
						commit
						514c7f1520
					
				|  | @ -170,7 +170,9 @@ app.state.MODELS = {} | ||||||
| origins = ["*"] | origins = ["*"] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def get_function_call_response(messages, tool_id, template, task_model_id, user): | async def get_function_call_response( | ||||||
|  |     messages, files, tool_id, template, task_model_id, user | ||||||
|  | ): | ||||||
|     tool = Tools.get_tool_by_id(tool_id) |     tool = Tools.get_tool_by_id(tool_id) | ||||||
|     tools_specs = json.dumps(tool.specs, indent=2) |     tools_specs = json.dumps(tool.specs, indent=2) | ||||||
|     content = tools_function_calling_generation_template(template, tools_specs) |     content = tools_function_calling_generation_template(template, tools_specs) | ||||||
|  | @ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, | ||||||
|                             "__messages__": messages, |                             "__messages__": messages, | ||||||
|                         } |                         } | ||||||
| 
 | 
 | ||||||
|  |                     if "__files__" in sig.parameters: | ||||||
|  |                         # Call the function with the '__files__' parameter included | ||||||
|  |                         params = { | ||||||
|  |                             **params, | ||||||
|  |                             "__files__": files, | ||||||
|  |                         } | ||||||
|  | 
 | ||||||
|                     function_result = function(**params) |                     function_result = function(**params) | ||||||
|                 except Exception as e: |                 except Exception as e: | ||||||
|                     print(e) |                     print(e) | ||||||
|  | @ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): | ||||||
|                     try: |                     try: | ||||||
|                         response = await get_function_call_response( |                         response = await get_function_call_response( | ||||||
|                             messages=data["messages"], |                             messages=data["messages"], | ||||||
|  |                             files=data.get("files", []), | ||||||
|                             tool_id=tool_id, |                             tool_id=tool_id, | ||||||
|                             template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, |                             template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, | ||||||
|                             task_model_id=task_model_id, |                             task_model_id=task_model_id, | ||||||
|  | @ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): | ||||||
| 
 | 
 | ||||||
|                 print(f"tool_context: {context}") |                 print(f"tool_context: {context}") | ||||||
| 
 | 
 | ||||||
|             # If docs field is present, generate RAG completions |             # TODO: Check if tools & functions have files support to skip this step to delegate file processing | ||||||
|  |             # If files field is present, generate RAG completions | ||||||
|             if "files" in data: |             if "files" in data: | ||||||
|                 data = {**data} |                 data = {**data} | ||||||
|                 rag_context, citations = get_rag_context( |                 rag_context, citations = get_rag_context( | ||||||
|  | @ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): | ||||||
|                 system_prompt = rag_template( |                 system_prompt = rag_template( | ||||||
|                     rag_app.state.config.RAG_TEMPLATE, context, prompt |                     rag_app.state.config.RAG_TEMPLATE, context, prompt | ||||||
|                 ) |                 ) | ||||||
| 
 |  | ||||||
|                 print(system_prompt) |                 print(system_prompt) | ||||||
| 
 |  | ||||||
|                 data["messages"] = add_or_update_system_message( |                 data["messages"] = add_or_update_system_message( | ||||||
|                     f"\n{system_prompt}", data["messages"] |                     f"\n{system_prompt}", data["messages"] | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|             modified_body_bytes = json.dumps(data).encode("utf-8") |             modified_body_bytes = json.dumps(data).encode("utf-8") | ||||||
| 
 |  | ||||||
|             # Replace the request body with the modified one |             # Replace the request body with the modified one | ||||||
|             request._body = modified_body_bytes |             request._body = modified_body_bytes | ||||||
|             # Set custom header to ensure content-length matches new body length |             # Set custom header to ensure content-length matches new body length | ||||||
|  | @ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         context = await get_function_call_response( |         context = await get_function_call_response( | ||||||
|             form_data["messages"], form_data["tool_id"], template, model_id, user |             form_data["messages"], | ||||||
|  |             form_data.get("files", []), | ||||||
|  |             form_data["tool_id"], | ||||||
|  |             template, | ||||||
|  |             model_id, | ||||||
|  |             user, | ||||||
|         ) |         ) | ||||||
|         return context |         return context | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|  |  | ||||||
|  | @ -587,22 +587,17 @@ | ||||||
| 		}); | 		}); | ||||||
| 
 | 
 | ||||||
| 		let files = []; | 		let files = []; | ||||||
| 
 |  | ||||||
| 		if (model?.info?.meta?.knowledge ?? false) { | 		if (model?.info?.meta?.knowledge ?? false) { | ||||||
| 			files = model.info.meta.knowledge; | 			files = model.info.meta.knowledge; | ||||||
| 		} | 		} | ||||||
| 
 | 		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); | ||||||
| 		files = [ | 		files = [ | ||||||
| 			...files, | 			...files, | ||||||
| 			...messages | 			...(lastUserMessage?.files?.filter((item) => | ||||||
| 				.filter((message) => message?.files ?? null) |  | ||||||
| 				.map((message) => |  | ||||||
| 					message.files.filter((item) => |  | ||||||
| 				['doc', 'file', 'collection', 'web_search_results'].includes(item.type) | 				['doc', 'file', 'collection', 'web_search_results'].includes(item.type) | ||||||
| 					) | 			) ?? []) | ||||||
| 				) |  | ||||||
| 				.flat(1) |  | ||||||
| 		].filter( | 		].filter( | ||||||
|  | 			// Remove duplicates | ||||||
| 			(item, index, array) => | 			(item, index, array) => | ||||||
| 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index | 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index | ||||||
| 		); | 		); | ||||||
|  | @ -832,22 +827,17 @@ | ||||||
| 		const responseMessage = history.messages[responseMessageId]; | 		const responseMessage = history.messages[responseMessageId]; | ||||||
| 
 | 
 | ||||||
| 		let files = []; | 		let files = []; | ||||||
| 
 |  | ||||||
| 		if (model?.info?.meta?.knowledge ?? false) { | 		if (model?.info?.meta?.knowledge ?? false) { | ||||||
| 			files = model.info.meta.knowledge; | 			files = model.info.meta.knowledge; | ||||||
| 		} | 		} | ||||||
| 
 | 		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); | ||||||
| 		files = [ | 		files = [ | ||||||
| 			...files, | 			...files, | ||||||
| 			...messages | 			...(lastUserMessage?.files?.filter((item) => | ||||||
| 				.filter((message) => message?.files ?? null) |  | ||||||
| 				.map((message) => |  | ||||||
| 					message.files.filter((item) => |  | ||||||
| 				['doc', 'file', 'collection', 'web_search_results'].includes(item.type) | 				['doc', 'file', 'collection', 'web_search_results'].includes(item.type) | ||||||
| 					) | 			) ?? []) | ||||||
| 				) |  | ||||||
| 				.flat(1) |  | ||||||
| 		].filter( | 		].filter( | ||||||
|  | 			// Remove duplicates | ||||||
| 			(item, index, array) => | 			(item, index, array) => | ||||||
| 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index | 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index | ||||||
| 		); | 		); | ||||||
|  |  | ||||||
|  | @ -153,6 +153,7 @@ | ||||||
| 
 | 
 | ||||||
| 			if (res) { | 			if (res) { | ||||||
| 				fileItem.status = 'processed'; | 				fileItem.status = 'processed'; | ||||||
|  | 				fileItem.collection_name = res.collection_name; | ||||||
| 				files = files; | 				files = files; | ||||||
| 			} | 			} | ||||||
| 		} catch (e) { | 		} catch (e) { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue