Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hook for chain execution tracing #98

Merged
merged 1 commit into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,13 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{18F5AAB1-1750-41BD-B623-6339CA5754D9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Ollama.IntegrationTests", "src\tests\LangChain.Providers.Ollama.IntegrationTests\LangChain.Providers.Ollama.IntegrationTests.csproj", "{72B1E2CC-1A34-470E-A579-034CB0972BB7}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Ollama.IntegrationTests", "src\tests\LangChain.Providers.Ollama.IntegrationTests\LangChain.Providers.Ollama.IntegrationTests.csproj", "{72B1E2CC-1A34-470E-A579-034CB0972BB7}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Ollama", "src\libs\Providers\LangChain.Providers.Ollama\LangChain.Providers.Ollama.csproj", "{4913844F-74EC-4E74-AE8A-EA825569E6BA}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Automatic1111", "src\libs\Providers\LangChain.Providers.Automatic1111\LangChain.Providers.Automatic1111.csproj", "{BF4C7B87-0997-4208-84EF-D368DF7B9861}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automatic1111", "src\libs\Providers\LangChain.Providers.Automatic1111\LangChain.Providers.Automatic1111.csproj", "{BF4C7B87-0997-4208-84EF-D368DF7B9861}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Automatic1111.IntegrationTests", "src\tests\LangChain.Providers.Automatic1111.IntegrationTests\LangChain.Providers.Automatic1111.IntegrationTests.csproj", "{A6CF79BC-8365-46E8-9230-1A4AD615D40B}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{738984A2-7D3F-42E7-9B4D-3528E2539197}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automatic1111.IntegrationTests", "src\tests\LangChain.Providers.Automatic1111.IntegrationTests\LangChain.Providers.Automatic1111.IntegrationTests.csproj", "{A6CF79BC-8365-46E8-9230-1A4AD615D40B}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down Expand Up @@ -404,10 +402,6 @@ Global
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.Build.0 = Release|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Debug|Any CPU.Build.0 = Debug|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Release|Any CPU.ActiveCfg = Release|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -475,7 +469,6 @@ Global
{4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{A6CF79BC-8365-46E8-9230-1A4AD615D40B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{738984A2-7D3F-42E7-9B4D-3528E2539197} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains.HelperChains.Exceptions;
using LangChain.Chains.StackableChains.Context;
using LangChain.Schema;

namespace LangChain.Chains.HelperChains;
Expand Down Expand Up @@ -86,10 +87,13 @@ string FormatInputValues(IChainValues values)
public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks = null,
IReadOnlyList<string>? tags = null, IReadOnlyDictionary<string, object>? metadata = null)
{


if (values == null)
{
throw new ArgumentNullException(nameof(values));
}

try
{
return InternalCall(values);
Expand All @@ -108,8 +112,9 @@ public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks =

throw new StackableChainException(message, ex);
}

}

/// <summary>
///
/// </summary>
Expand Down Expand Up @@ -143,9 +148,11 @@ public static StackChain BitwiseOr(BaseStackableChain left, BaseStackableChain r
///
/// </summary>
/// <returns></returns>
public async Task<IChainValues> Run()
public async Task<IChainValues> Run(StackableChainHook? hook=null)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
var values = new StackableChainValues() {Hook = hook};
hook?.ChainStart(values);
var res = await CallAsync(values).ConfigureAwait(false);
return res;
}

Expand All @@ -154,9 +161,9 @@ public async Task<IChainValues> Run()
/// </summary>
/// <param name="resultKey"></param>
/// <returns></returns>
public async Task<string?> Run(string resultKey)
public async Task<string?> Run(string resultKey, StackableChainHook? hook = null)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
var res = await CallAsync(new StackableChainValues() { Hook = hook }).ConfigureAwait(false);
return res.Value[resultKey].ToString();
}

Expand All @@ -166,12 +173,17 @@ public async Task<IChainValues> Run()
/// <param name="resultKey"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public async Task<T> Run<T>(string resultKey)
public async Task<T> Run<T>(string resultKey, StackableChainHook? hook = null)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
var res = await CallAsync(new StackableChainValues() { Hook = hook }).ConfigureAwait(false);
return (T)res.Value[resultKey];
}

public Task<string?> Run(string resultKey)
{
return Run(resultKey, null);
}

/// <summary>
///
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@

using LangChain.Chains.HelperChains;

namespace LangChain.Chains.StackableChains.Context;

public class ConsoleTraceHook: StackableChainHook

Check warning on line 6 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook'
{
public bool UseColors { get; set; }=true;

Check warning on line 8 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.UseColors'
public int ValuesLength { get; set; } = 40;

Check warning on line 9 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.ValuesLength'
public override void ChainStart(StackableChainValues values)

Check warning on line 10 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.ChainStart(StackableChainValues)'
{
Console.WriteLine();
}
public override void LinkEnter(BaseStackableChain chain, StackableChainValues values)

Check warning on line 14 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.LinkEnter(BaseStackableChain, StackableChainValues)'
{

Console.Write("|");
Console.Write(chain.GetType().Name);
Console.WriteLine();
if (chain.InputKeys.Count > 0)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
Console.Write("Input:");
Console.WriteLine();
foreach (string inputKey in chain.InputKeys)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
var value = values.Value[inputKey];
var oldColor = Console.ForegroundColor;
Console.ForegroundColor = GetColorForKey(inputKey);
Console.Write($" {inputKey}={ShortenString(value.ToString() ?? "", ValuesLength)}");
Console.ForegroundColor = oldColor;
Console.WriteLine();
}
}


}

Dictionary<string, ConsoleColor> _colorMap = new Dictionary<string, ConsoleColor>();

ConsoleColor GetColorForKey(string key)
{
if(!UseColors)
return Console.ForegroundColor;
// if key is not in map, get unique color(except black and white)
// if there no unique colors left, return white
if (!_colorMap.ContainsKey(key))
{
var color = ConsoleColor.White;
var colors = Enum.GetValues(typeof(ConsoleColor));
foreach (ConsoleColor c in colors)
{
if (c == ConsoleColor.Black || c == ConsoleColor.White)
continue;
if (!_colorMap.ContainsValue(c))
{
color = c;
break;
}
}
_colorMap.Add(key, color);
}
return _colorMap[key];
}

string ShortenString(string str, int length)
{
if (str.Length <= length)
return str;
return str.Substring(0, length - 3) + "...";
}

public override void LinkExit(BaseStackableChain chain, StackableChainValues values)

Check warning on line 78 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.LinkExit(BaseStackableChain, StackableChainValues)'
{
if (chain.OutputKeys.Count > 0)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
Console.Write("Output:");
Console.WriteLine();
foreach (string outputKey in chain.OutputKeys)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
var value = values.Value[outputKey];
var oldColor = Console.ForegroundColor;
Console.ForegroundColor = GetColorForKey(outputKey);
Console.Write($" {outputKey}={ShortenString(value.ToString() ?? "", ValuesLength)}");
Console.ForegroundColor = oldColor;
Console.WriteLine();
}
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using LangChain.Chains.HelperChains;

namespace LangChain.Chains.StackableChains.Context;

public class StackableChainHook
{
public virtual void ChainStart(StackableChainValues values)
{

}

public virtual void LinkEnter(BaseStackableChain chain, StackableChainValues values)
{

}

public virtual void LinkExit(BaseStackableChain chain, StackableChainValues values)
{

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using LangChain.Schema;

namespace LangChain.Chains.StackableChains.Context;

public class StackableChainValues : ChainValues
{
public StackableChainHook? Hook { get; set; }
}
13 changes: 12 additions & 1 deletion src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LangChain.Abstractions.Schema;
using LangChain.Chains.StackableChains.Context;
using LangChain.Schema;

namespace LangChain.Chains.HelperChains;
Expand Down Expand Up @@ -60,15 +61,25 @@

if (IsolatedInputKeys.Count > 0)
{
var res = new ChainValues();
var res = new StackableChainValues(){Hook = (values as StackableChainValues)?.Hook};
foreach (var key in IsolatedInputKeys)
{
res.Value[key] = values.Value[key];
}
values = res;
}
if(a is not StackChain)
(values as StackableChainValues)?.Hook?.LinkEnter(a, (values as StackableChainValues));

Check warning on line 72 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkEnter(BaseStackableChain chain, StackableChainValues values)'.
await a.CallAsync(values).ConfigureAwait(false);
if (a is not StackChain)
(values as StackableChainValues)?.Hook?.LinkExit(a, (values as StackableChainValues));

Check warning on line 75 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkExit(BaseStackableChain chain, StackableChainValues values)'.

if (b is not StackChain)
(values as StackableChainValues)?.Hook?.LinkEnter(b, (values as StackableChainValues));

Check warning on line 78 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkEnter(BaseStackableChain chain, StackableChainValues values)'.
await b.CallAsync(values).ConfigureAwait(false);
if (b is not StackChain)
(values as StackableChainValues)?.Hook?.LinkExit(b, (values as StackableChainValues));

Check warning on line 81 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkExit(BaseStackableChain chain, StackableChainValues values)'.

if (IsolatedOutputKeys.Count > 0)
{
foreach (var key in IsolatedOutputKeys)
Expand Down