From 1c685bb8c35ca34d62a1e28d6518f925b369a0b3 Mon Sep 17 00:00:00 2001 From: Riccardo Balbo Date: Sun, 12 May 2024 10:07:45 +0200 Subject: [PATCH] Add tools support --- src/rag/index.js | 44 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/src/rag/index.js b/src/rag/index.js index 7ae4db8..85ade17 100644 --- a/src/rag/index.js +++ b/src/rag/index.js @@ -12,6 +12,8 @@ async function run() { let cacheDurationHint = undefined; let noCache = false; let warmUp = false; + let useTools=false; + let toolsResultTemplate ="{{TOOL_RESULT}}"; const documents = []; // plain text documents const documentsUrls = []; // urls to documents @@ -29,6 +31,12 @@ async function run() { } } } + + // Early return if no documents + if (documents.length === 0 && documentsUrls.length === 0) { + Host.outputString(""); + return; + } for(const param of job.param){ @@ -46,10 +54,13 @@ async function run() { }else if(param.key=="cache-duration-hint"){ cacheDurationHint=param.value; }else if(param.key=="no-cache"){ - noCache=param.value=="true"; + noCache = param.value=="true"; }else if(param.key=="warm-up"){ - noCache||=param.value=="true"; - warmUp||=param.value=="true"; + warmUp = param.value == "true"; + }else if(param.key=="use-tools"){ + useTools=param.value=="true"; + }else if(param.key=="tools-result-template"){ + toolsResultTemplate=param.value; } } @@ -57,10 +68,26 @@ async function run() { if(cacheDurationHint){ cacheParams.push(await Job.newParam("cache-duration-hint", cacheDurationHint)); } - if(noCache){ + if(noCache||warmUp){ cacheParams.push(await Job.newParam("no-cache", noCache)); } + // Send tool req + let toolReq; + if (!warmUp && useTools){ + Job.log("Send tool request..."); + toolReq = Job.subrequest({ + runOn: "openagents/tools", + outputFormat: "application/json", + inputs: [ + await Job.newInputData(queries, "text", "queries") + ], + params: [ + ...cacheParams + ] + }); + } + Job.log("Starting rag pipeline with k="+topK+", max-tokens="+maxTokens+", quantize="+quantize+", overlap="+overlap+", cache-duration-hint="+cacheDurationHint+", no-cache="+noCache); Job.log("Fetch documents..."); const downloadDocumentsReq = Job.subrequest({ @@ -108,7 +135,7 @@ async function run() { Job.log("Search..."); const searchReq = Job.subrequest({ runOn: "openagents/search", - outputFormat: "application/hyperdrive+bundle", + outputFormat: "application/json", inputs: [ await Job.newInputData(documentsEmbeddingBundle, "application/hyperdrive+bundle", "index"), await Job.newInputData(queriesEmbeddingBundle, quantize?"application/json":"application/hyperdrive+bundle", "query") @@ -121,13 +148,18 @@ async function run() { }); try{ - const searchResult = JSON.parse(await Job.waitForContent(searchReq)); Job.log("Merge context... "+searchResult.length+" results found"); let newContext =""; for(const result of searchResult){ newContext+=result.value+"\n"; } + if(toolReq){ + Job.log("Merge tools result..."); + let toolResult = await Job.waitForContent(toolReq); + toolResult = toolsResultTemplate.replace("{{TOOL_RESULT}}", toolResult); + newContext+=toolResult+"\n"; + } Host.outputString(newContext); }catch(e){ await Job.log("Error! "+e.message);