fix: rag
This commit is contained in:
		
							parent
							
								
									bcc27e3852
								
							
						
					
					
						commit
						514c7f1520
					
				|  | @ -170,7 +170,9 @@ app.state.MODELS = {} | |||
| origins = ["*"] | ||||
| 
 | ||||
| 
 | ||||
| async def get_function_call_response(messages, tool_id, template, task_model_id, user): | ||||
| async def get_function_call_response( | ||||
|     messages, files, tool_id, template, task_model_id, user | ||||
| ): | ||||
|     tool = Tools.get_tool_by_id(tool_id) | ||||
|     tools_specs = json.dumps(tool.specs, indent=2) | ||||
|     content = tools_function_calling_generation_template(template, tools_specs) | ||||
|  | @ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, | |||
|                             "__messages__": messages, | ||||
|                         } | ||||
| 
 | ||||
|                     if "__files__" in sig.parameters: | ||||
|                         # Call the function with the '__files__' parameter included | ||||
|                         params = { | ||||
|                             **params, | ||||
|                             "__files__": files, | ||||
|                         } | ||||
| 
 | ||||
|                     function_result = function(**params) | ||||
|                 except Exception as e: | ||||
|                     print(e) | ||||
|  | @ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): | |||
|                     try: | ||||
|                         response = await get_function_call_response( | ||||
|                             messages=data["messages"], | ||||
|                             files=data.get("files", []), | ||||
|                             tool_id=tool_id, | ||||
|                             template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, | ||||
|                             task_model_id=task_model_id, | ||||
|  | @ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): | |||
| 
 | ||||
|                 print(f"tool_context: {context}") | ||||
| 
 | ||||
|             # If docs field is present, generate RAG completions | ||||
|             # TODO: Check if tools & functions have files support to skip this step to delegate file processing | ||||
|             # If files field is present, generate RAG completions | ||||
|             if "files" in data: | ||||
|                 data = {**data} | ||||
|                 rag_context, citations = get_rag_context( | ||||
|  | @ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): | |||
|                 system_prompt = rag_template( | ||||
|                     rag_app.state.config.RAG_TEMPLATE, context, prompt | ||||
|                 ) | ||||
| 
 | ||||
|                 print(system_prompt) | ||||
| 
 | ||||
|                 data["messages"] = add_or_update_system_message( | ||||
|                     f"\n{system_prompt}", data["messages"] | ||||
|                 ) | ||||
| 
 | ||||
|             modified_body_bytes = json.dumps(data).encode("utf-8") | ||||
| 
 | ||||
|             # Replace the request body with the modified one | ||||
|             request._body = modified_body_bytes | ||||
|             # Set custom header to ensure content-length matches new body length | ||||
|  | @ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ | |||
| 
 | ||||
|     try: | ||||
|         context = await get_function_call_response( | ||||
|             form_data["messages"], form_data["tool_id"], template, model_id, user | ||||
|             form_data["messages"], | ||||
|             form_data.get("files", []), | ||||
|             form_data["tool_id"], | ||||
|             template, | ||||
|             model_id, | ||||
|             user, | ||||
|         ) | ||||
|         return context | ||||
|     except Exception as e: | ||||
|  |  | |||
|  | @ -587,22 +587,17 @@ | |||
| 		}); | ||||
| 
 | ||||
| 		let files = []; | ||||
| 
 | ||||
| 		if (model?.info?.meta?.knowledge ?? false) { | ||||
| 			files = model.info.meta.knowledge; | ||||
| 		} | ||||
| 
 | ||||
| 		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); | ||||
| 		files = [ | ||||
| 			...files, | ||||
| 			...messages | ||||
| 				.filter((message) => message?.files ?? null) | ||||
| 				.map((message) => | ||||
| 					message.files.filter((item) => | ||||
| 			...(lastUserMessage?.files?.filter((item) => | ||||
| 				['doc', 'file', 'collection', 'web_search_results'].includes(item.type) | ||||
| 					) | ||||
| 				) | ||||
| 				.flat(1) | ||||
| 			) ?? []) | ||||
| 		].filter( | ||||
| 			// Remove duplicates | ||||
| 			(item, index, array) => | ||||
| 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index | ||||
| 		); | ||||
|  | @ -832,22 +827,17 @@ | |||
| 		const responseMessage = history.messages[responseMessageId]; | ||||
| 
 | ||||
| 		let files = []; | ||||
| 
 | ||||
| 		if (model?.info?.meta?.knowledge ?? false) { | ||||
| 			files = model.info.meta.knowledge; | ||||
| 		} | ||||
| 
 | ||||
| 		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); | ||||
| 		files = [ | ||||
| 			...files, | ||||
| 			...messages | ||||
| 				.filter((message) => message?.files ?? null) | ||||
| 				.map((message) => | ||||
| 					message.files.filter((item) => | ||||
| 			...(lastUserMessage?.files?.filter((item) => | ||||
| 				['doc', 'file', 'collection', 'web_search_results'].includes(item.type) | ||||
| 					) | ||||
| 				) | ||||
| 				.flat(1) | ||||
| 			) ?? []) | ||||
| 		].filter( | ||||
| 			// Remove duplicates | ||||
| 			(item, index, array) => | ||||
| 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index | ||||
| 		); | ||||
|  |  | |||
|  | @ -153,6 +153,7 @@ | |||
| 
 | ||||
| 			if (res) { | ||||
| 				fileItem.status = 'processed'; | ||||
| 				fileItem.collection_name = res.collection_name; | ||||
| 				files = files; | ||||
| 			} | ||||
| 		} catch (e) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue