Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Parzival-05 committed Feb 21, 2025
1 parent d34c719 commit 954fdb8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 28 deletions.
27 changes: 11 additions & 16 deletions VSharp.Explorer/AISearcher.fs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ namespace VSharp.Explorer

open System.Collections.Generic
open Microsoft.ML.OnnxRuntime
open System.IO
open System
open System.Net
open System.Net.Sockets
open System.Text
open System.Text.Json
open VSharp
Expand Down Expand Up @@ -231,6 +228,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
let arrayOutputJson =
JsonSerializer.Serialize arrayOutput
arrayOutputJson

let stepToString (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
let gameStateJson =
JsonSerializer.Serialize gameState
Expand All @@ -246,22 +244,19 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
String.concat " " strToSaveAsList

let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>) =
let host = "localhost"
let port =
let stream =
match aiAgentTrainingModelOptions with
| Some options -> options.port
| None -> 0

let client = new TcpClient ()
client.Connect (host, port)
client.SendBufferSize <- 2048
let stream = client.GetStream ()
| Some options -> options.stream
| None -> None

let saveStep (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
let bytes =
Encoding.UTF8.GetBytes (stepToString gameState output)
stream.Write (bytes, 0, bytes.Length)
stream.Flush ()
match stream with
| Some stream ->
let bytes =
Encoding.UTF8.GetBytes (stepToString gameState output)
stream.Write (bytes, 0, bytes.Length)
stream.Flush ()
| None -> ()

let sessionOptions =
if useGPU then
Expand Down
3 changes: 2 additions & 1 deletion VSharp.Explorer/Options.fs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace VSharp.Explorer
open System.Diagnostics
open System.IO
open VSharp.ML.GameServer.Messages
open System.Net.Sockets

type searchMode =
| DFSMode
Expand Down Expand Up @@ -78,7 +79,7 @@ type AIAgentTrainingModelOptions =
{
aiAgentTrainingOptions: AIAgentTrainingOptions
outputDirectory: string
port: int
stream: Option<NetworkStream> // use it for sending steps
}


Expand Down
40 changes: 29 additions & 11 deletions VSharp.ML.GameServer.Runner/Main.fs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
open System.IO
open System.Net.Sockets
open System.Reflection
open Argu
open Microsoft.FSharp.Core
Expand Down Expand Up @@ -36,6 +37,9 @@ type Mode =
| Generator = 1
| SendModel = 2

let TIMEOUT_FOR_TRAINING = 15 * 60
let SOLVER_TIMEOUT_FOR_TRAINING = 2

type CliArguments =
| [<Unique>] Port of int
| [<Unique>] DatasetBasePath of string
Expand Down Expand Up @@ -241,12 +245,12 @@ let ws port outputDirectory (webSocket: WebSocket) (context: HttpContext) =

let options =
VSharpOptions (
timeout = 15 * 60,
timeout = TIMEOUT_FOR_TRAINING,
outputDirectory = outputDirectory,
searchStrategy = SearchStrategy.AI,
aiOptions = (Some aiOptions |> Option.defaultValue Unchecked.defaultof<_>),
stepsLimit = uint (stepsToPlay + stepsToStart),
solverTimeout = 2
solverTimeout = SOLVER_TIMEOUT_FOR_TRAINING
)

let explorationResult =
Expand Down Expand Up @@ -281,6 +285,9 @@ let app port outputDirectory : WebPart =
path "/gameServer" >=> handShake (ws port outputDirectory)
]

let serializeExplorationResult (explorationResult: ExplorationResult) =
$"{explorationResult.ActualCoverage} {explorationResult.TestsCount} {explorationResult.StepsCount} {explorationResult.ErrorsCount}"

let generateDataForPretraining outputDirectory datasetBasePath (maps: ResizeArray<GameMap>) stepsToSerialize =
for map in maps do
if map.StepsToStart = 0u<step> then
Expand Down Expand Up @@ -325,10 +332,7 @@ let generateDataForPretraining outputDirectory datasetBasePath (maps: ResizeArra

let explorationResult = explore map options

File.WriteAllText (
Path.Join (folderForResults, "result"),
$"{explorationResult.ActualCoverage} {explorationResult.TestsCount} {explorationResult.StepsCount} {explorationResult.ErrorsCount}"
)
File.WriteAllText (Path.Join (folderForResults, "result"), serializeExplorationResult explorationResult)

printfn
$"Generation for {map.MapName} finished with coverage {explorationResult.ActualCoverage}, tests {explorationResult.TestsCount}, steps {explorationResult.StepsCount},errors {explorationResult.ErrorsCount}."
Expand All @@ -337,7 +341,14 @@ let generateDataForPretraining outputDirectory datasetBasePath (maps: ResizeArra
API.Reset ()
HashMap.hashMap.Clear ()

let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) (port: int) =
let runTrainingSendModelMode
outputDirectory
(gameMap: GameMap)
(pathToModel: string)
(useGPU: bool)
(optimize: bool)
(port: int)
=
printfn $"Run infer on {gameMap.MapName} have started."

let aiTrainingOptions =
Expand All @@ -357,22 +368,29 @@ let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: st
oracle = None
}

let stream =
let host = "localhost" // TODO: working within a local network
let client = new TcpClient ()
client.Connect (host, port)
client.SendBufferSize <- 2048
Some <| client.GetStream ()

let aiOptions: AIOptions =
Training (
SendModel
{
aiAgentTrainingOptions = aiTrainingOptions
outputDirectory = outputDirectory
port = port
stream = stream
}
)

let options =
VSharpOptions (
timeout = 15 * 60,
timeout = TIMEOUT_FOR_TRAINING,
outputDirectory = outputDirectory,
searchStrategy = SearchStrategy.AI,
solverTimeout = 2,
solverTimeout = SOLVER_TIMEOUT_FOR_TRAINING,
aiOptions = (Some aiOptions |> Option.defaultValue Unchecked.defaultof<_>),
pathToModel = pathToModel,
useGPU = useGPU,
Expand All @@ -384,7 +402,7 @@ let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: st

File.WriteAllText (
Path.Join (outputDirectory, gameMap.MapName + "result"),
$"{explorationResult.ActualCoverage} {explorationResult.TestsCount} {explorationResult.StepsCount} {explorationResult.ErrorsCount}"
serializeExplorationResult explorationResult
)

printfn
Expand Down

0 comments on commit 954fdb8

Please sign in to comment.