From c05ce472299e47f577b4ab705f321ae6a6891014 Mon Sep 17 00:00:00 2001 From: gsv Date: Wed, 17 Jan 2024 19:14:18 +0300 Subject: [PATCH] Prepare data for ONNX model. --- VSharp.API/VSharpOptions.cs | 4 +- VSharp.Explorer/AISearcher.fs | 152 +++++++++++++++++++++++++ VSharp.Explorer/VSharp.Explorer.fsproj | 2 + VSharp.Runner/RunnerProgram.cs | 8 +- 4 files changed, 160 insertions(+), 6 deletions(-) diff --git a/VSharp.API/VSharpOptions.cs b/VSharp.API/VSharpOptions.cs index 68944b1ff..2433babad 100644 --- a/VSharp.API/VSharpOptions.cs +++ b/VSharp.API/VSharpOptions.cs @@ -112,7 +112,7 @@ public readonly record struct VSharpOptions public readonly int RandomSeed = DefaultRandomSeed; public readonly uint StepsLimit = DefaultStepsLimit; public readonly Oracle? Oracle = null; - public readonly AIAgentTrainingOptions? AIAgentTrainingOptions = null; + public readonly AIAgentTrainingOptions AIAgentTrainingOptions = null; /// /// Symbolic virtual machine options. @@ -143,7 +143,7 @@ public VSharpOptions( int randomSeed = DefaultRandomSeed, uint stepsLimit = DefaultStepsLimit, Oracle? oracle = null, - AIAgentTrainingOptions? aiAgentTrainingOptions = null) + AIAgentTrainingOptions aiAgentTrainingOptions = null) { Timeout = timeout; SolverTimeout = solverTimeout; diff --git a/VSharp.Explorer/AISearcher.fs b/VSharp.Explorer/AISearcher.fs index 74250ef09..a0c372f45 100644 --- a/VSharp.Explorer/AISearcher.fs +++ b/VSharp.Explorer/AISearcher.fs @@ -1,6 +1,7 @@ namespace VSharp.Explorer open System.Collections.Generic +open Microsoft.ML.OnnxRuntime open VSharp open VSharp.IL.Serializer open VSharp.ML.GameServer.Messages @@ -144,6 +145,157 @@ type internal AISearcher(oracle:Oracle, aiAgentTrainingOptions: Option,int>() + let networkInput = + let verticesIds = Dictionary,int>() + let res = Dictionary<_,_>() + let gameVertices = + let shape = [| int64 gameState.GraphVertices.Length; numOfVertexAttributes |] + let attributes = Array.zeroCreate (gameState.GraphVertices.Length * numOfVertexAttributes) + for i in 0..gameState.GraphVertices.Length - 1 do + let v = gameState.GraphVertices.[i] + verticesIds.Add(v.Id,i) + let i = i*numOfVertexAttributes + attributes.[i] <- if v.InCoverageZone then 1u else 0u + attributes.[i + 1] <- v.BasicBlockSize + attributes.[i + 2] <- if v.CoveredByTest then 1u else 0u + attributes.[i + 3] <- if v.VisitedByState then 1u else 0u + attributes.[i + 4] <- if v.TouchedByState then 1u else 0u + attributes.[i + 5] <- if v.ContainsCall then 1u else 0u + attributes.[i + 6] <- if v.ContainsThrow then 1u else 0u + OrtValue.CreateTensorValueFromMemory(attributes, shape) + + let states, numOfParentOfEdges, numOfHistoryEdges = + let mutable numOfParentOfEdges = 0 + let mutable numOfHistoryEdges = 0 + let shape = [| int64 gameState.States.Length; numOfStateAttributes |] + let attributes = Array.zeroCreate (gameState.States.Length * numOfStateAttributes) + for i in 0..gameState.States.Length - 1 do + let v = gameState.States.[i] + numOfHistoryEdges <- numOfHistoryEdges + v.History.Length + numOfParentOfEdges <- numOfParentOfEdges + v.Children.Length + stateIds.Add(v.Id,i) + let i = i*numOfStateAttributes + attributes.[i] <- uint v.Position + attributes.[i + 1] <- uint v.PathConditionSize + attributes.[i + 2] <- uint v.VisitedAgainVertices + attributes.[i + 3] <- uint v.VisitedNotCoveredVerticesInZone + attributes.[i + 4] <- uint v.VisitedNotCoveredVerticesOutOfZone + attributes.[i + 5] <- uint v.InstructionsVisitedInCurrentBlock + attributes.[i + 6] <- uint v.StepWhenMovedLastTime + OrtValue.CreateTensorValueFromMemory(attributes, shape) + ,numOfParentOfEdges + ,numOfHistoryEdges + + let vertexToVertexEdgesIndex,vertexToVertexEdgesAttributes = + let shapeOfIndex = [| 2L; gameState.Map.Length |] + let shapeOfAttributes = [| 1L; gameState.Map.Length |] + let index = Array.zeroCreate (2 * gameState.Map.Length) + let attributes = Array.zeroCreate gameState.Map.Length + gameState.Map + |> Array.iteri ( + fun i e -> + index[i * 2] <- verticesIds[e.VertexFrom] + index[i * 2 + 1] <- verticesIds[e.VertexTo] + attributes[i] <- e.Label.Token + ) + + OrtValue.CreateTensorValueFromMemory(index, shapeOfIndex) + , OrtValue.CreateTensorValueFromMemory(attributes, shapeOfAttributes) + + let historyEdgesIndex_vertexToState, historyEdgesAttributes, parentOfEdges = + let shapeOfParentOf = [| 2L; numOfParentOfEdges |] + let parentOf = Array.zeroCreate (2 * numOfParentOfEdges) + let shapeOfHistory = [|2L; numOfHistoryEdges|] + let historyIndex_vertexToState = Array.zeroCreate (2 * numOfHistoryEdges) + let shapeOfHistoryAttributes = [| int64 numOfHistoryEdgeAttributes; numOfHistoryEdges |] + let historyAttributes = Array.zeroCreate (2 * numOfHistoryEdges) + let mutable firstFreePositionInParentsOf = 0 + let mutable firstFreePositionInHistoryIndex = 0 + let mutable firstFreePositionInHistoryAttributes = 0 + gameState.States + |> Array.iter (fun v -> + v.Children + |> Array.iteri (fun i s -> + let i = firstFreePositionInParentsOf + 2 * i + parentOf[i] <- stateIds[v.Id] + parentOf[i + 1] <- stateIds[s] + ) + firstFreePositionInParentsOf <- firstFreePositionInParentsOf + 2 * v.Children.Length + v.History + |> Array.iteri (fun i s -> + let j = firstFreePositionInHistoryIndex + 2 * i + historyIndex_vertexToState[j] <- verticesIds[s.GraphVertexId] + historyIndex_vertexToState[j + 1] <- stateIds[v.Id] + let j = firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * i + historyAttributes[j] <- s.NumOfVisits + historyAttributes[j] <- uint s.StepWhenVisitedLastTime + ) + firstFreePositionInHistoryIndex <- firstFreePositionInHistoryIndex + 2 * v.History.Length + firstFreePositionInHistoryAttributes <- firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * v.History.Length + ) + OrtValue.CreateTensorValueFromMemory(historyIndex_vertexToState, shapeOfHistory) + , OrtValue.CreateTensorValueFromMemory(historyAttributes, shapeOfHistoryAttributes) + , OrtValue.CreateTensorValueFromMemory(parentOf, shapeOfParentOf) + + let statePosition_stateToVertex, statePosition_vertexToState = + let data_stateToVertex = Array.zeroCreate (2 * gameState.States.Length) + let data_vertexToState = Array.zeroCreate (2 * gameState.States.Length) + let shape = [|2L; gameState.States.Length|] + let mutable firstFreePosition = 0 + gameState.GraphVertices + |> Array.iter ( + fun v -> + v.States + |> Array.iteri (fun i s -> + let startPos = firstFreePosition + i * 2 + let s = stateIds[s] + let v= verticesIds[v.Id] + data_stateToVertex[startPos] <- s + data_stateToVertex[startPos + 1] <- v + + data_vertexToState[startPos] <- v + data_vertexToState[startPos + 1] <- s + ) + firstFreePosition <- firstFreePosition + 2 * v.States.Length + ) + OrtValue.CreateTensorValueFromMemory(data_stateToVertex, shape) + ,OrtValue.CreateTensorValueFromMemory(data_vertexToState, shape) + + res.Add ("game_vertex", gameVertices) + res.Add ("state_vertex", states) + res.Add ("game_vertex to game_vertex", vertexToVertexEdgesIndex) + res.Add ("game_vertex history state_vertex index", historyEdgesIndex_vertexToState) + res.Add ("game_vertex history state_vertex attrs", historyEdgesAttributes) + res.Add ("game_vertex in state_vertex", statePosition_vertexToState) + res.Add ("state_vertex parent_of state_vertex", parentOfEdges) + res + + let output = session.Run(runOptions, networkInput, session.OutputNames) + let weighedStates = output[0].GetTensorDataAsSpan().ToArray() + + let id = + weighedStates + |> Array.mapi (fun i v -> i,v) + |> Array.maxBy snd + |> fst + stateIds + |> Seq.find (fun kvp -> kvp.Value = id) + |> fun x -> x.Key + + Oracle(predict,feedback) + + AISearcher(createOracle pathToONNX, None) + interface IForwardSearcher with override x.Init states = init states override x.Pick() = pick (always true) diff --git a/VSharp.Explorer/VSharp.Explorer.fsproj b/VSharp.Explorer/VSharp.Explorer.fsproj index 611f5d047..818c48218 100644 --- a/VSharp.Explorer/VSharp.Explorer.fsproj +++ b/VSharp.Explorer/VSharp.Explorer.fsproj @@ -46,5 +46,7 @@ + + diff --git a/VSharp.Runner/RunnerProgram.cs b/VSharp.Runner/RunnerProgram.cs index f6d0275ec..2729cf74d 100644 --- a/VSharp.Runner/RunnerProgram.cs +++ b/VSharp.Runner/RunnerProgram.cs @@ -92,10 +92,10 @@ public static class RunnerProgram //method = type.GetMethod(t.Last(), Reflection.allBindingFlags); //method = type.GetMethod(methodArgumentValue, Reflection.allBindingFlags); var x = type.GetMethods(Reflection.allBindingFlags); - foreach (var m in x) - { - // Console.WriteLine($"{type.FullName}.{m.Name}"); - } + //foreach (var m in x) + //{ + // Console.WriteLine($"{type.FullName}.{m.Name}"); + //} method ??= x .Where(m => type.FullName.Split('.').Last().Contains(className) && m.Name.Contains(methodName)) .MinBy(m => m.Name.Length);