diff --git a/godot-rwkv.h b/godot-rwkv.h index af19b15..393ce57 100644 --- a/godot-rwkv.h +++ b/godot-rwkv.h @@ -1,41 +1,32 @@ #ifndef RWKV_GODOT_H #define RWKV_GODOT_H - #undef VK_USE_PLATFORM_XLIB_KHR - #include "rwkv.h" -#include "tokenizer/tokenizer.hpp" #include "sampler/sample.h" - - - +#include "tokenizer/tokenizer.hpp" #include "core/io/resource.h" #include "core/object/ref_counted.h" - - - - - class Agent : public Resource { GDCLASS(Agent, Resource); - public: +public: std::map state = {}; std::vector stop_sequences = {}; size_t max_queued_tokens = 0; float temperature = 3.0; float tau = 0.6; size_t last_token = 187; - RWKVTokenizer* tokenizer = nullptr; + RWKVTokenizer *tokenizer = nullptr; std::vector context = {}; std::string add_context_queue = ""; bool busy = false; - - Agent(RWKV* model, RWKVTokenizer* tokenizeri) { + size_t state_index = 0; + + Agent(RWKV *model, RWKVTokenizer *tokenizeri) { state = model->new_state(); tokenizer = tokenizeri; } @@ -45,25 +36,65 @@ class Agent : public Resource { int add_context(String contexta) { // assert that add_context_queue is empty - // assert that max_queued_tokens is 0 + if (max_queued_tokens != 0 || add_context_queue != "" || busy) { ERR_PRINT("add_context_queue is not empty or max_queued_tokens is not 0"); return -1; } - add_context_queue = std::string(contexta.utf8().get_data()); auto tokens = tokenizer->encode(add_context_queue); context.clear(); busy = true; + + // assert that max_queued_tokens is 0 + return 0; } + void sample_output(float* data, RWKV *model) { + auto out = data; + auto token = dart(data, temperature); + + max_queued_tokens = max_queued_tokens - 1; + + // check if stop sequence + bool stopped = (token == 0); + if ((context.size() > 5) && token != 0) { + std::cout << "last token: " << token << std::endl; + auto tokstocheck = std::vector(context.end() - 5, context.end()); + tokstocheck.push_back(token); + std::string context5toks = tokenizer->decode(tokstocheck); + for (size_t j = 0; j < stop_sequences.size(); j++) { + auto stop_sequence = stop_sequences[j]; + + if (context5toks.find(stop_sequence) != std::string::npos) { + max_queued_tokens = 0; + stopped = true; + } + } + } + + if (!stopped) { + model->get_state(state, state_index); + last_token = token; + context.push_back(token); + } + + if (max_queued_tokens == 0) { + busy = false; + } + } + bool is_busy() { return busy; } - void generate(int tokens){ - max_queued_tokens = (size_t)tokens; + void generate(int tokens) { + max_queued_tokens = (size_t)tokens; + } + + void clearContext(){ + add_context_queue.clear(); } void set_temperature(float temp) { @@ -86,13 +117,15 @@ class Agent : public Resource { // threadsafe return context String get_context() { + String outtext; + if (context.size() == 0) { return ""; } - auto contexta = tokenizer->decode(context); - String contexts; - contexts.parse_utf8(contexta.c_str()); - return contexts; + auto contexta = tokenizer->decode(context); + outtext.parse_utf8(contexta.c_str()); + + return outtext; } // get last token @@ -105,7 +138,7 @@ class Agent : public Resource { return (int)max_queued_tokens; } - protected: +protected: static void _bind_methods() { ClassDB::bind_method(D_METHOD("add_context", "Context"), &Agent::add_context); ClassDB::bind_method(D_METHOD("generate", "Tokens"), &Agent::generate); @@ -116,32 +149,22 @@ class Agent : public Resource { ClassDB::bind_method(D_METHOD("get_context"), &Agent::get_context); ClassDB::bind_method(D_METHOD("get_max_queued_tokens"), &Agent::get_max_queued_tokens); } - }; class GodotRWKV : public Resource { GDCLASS(GodotRWKV, Resource); - - - public: - RWKV* model = nullptr; - RWKVTokenizer* tokenizer = nullptr; + RWKV *model = nullptr; + RWKVTokenizer *tokenizer = nullptr; size_t lastToken = 187; - std::vector agents = {}; + std::vector agents = {}; GodotRWKV() { - } - void loadModel(String path, int NumThreads = 8) { - - if (NumThreads != 1 && NumThreads != 2 && NumThreads != 4 && NumThreads != 8) { - ERR_PRINT("NumThreads must be 8, 4, 2 or 1"); - return; - } - - model = new RWKV(std::string(path.utf8().get_data()), size_t(NumThreads)); + void loadModel(String path, int NumThreads = 0) { + model = new RWKV(std::string(path.utf8().get_data()), size_t(NumThreads), 0, 32); + start(); }; void loadTokenizer(String path) { @@ -149,123 +172,112 @@ class GodotRWKV : public Resource { tokenizer = new RWKVTokenizer(std::string(path.utf8().get_data())); }; - + void start() { + auto pool = get_threadpool(0); + pool->add_job([&] { + listen(); + }, + 0); + }; + void listen() { - - // sleep - // do context processing - if (agents.size() > 0) { - std::vector toProcess = {}; - for (size_t i = 0; i < agents.size(); i++) { - if (agents[i]->add_context_queue != "") { - std::cout << "processing context" << std::endl; - auto tokens = tokenizer->encode(agents[i]->add_context_queue); - std::cout << "tokens: " << tokens.size() << std::endl; - model->set_state(agents[i]->state, 0); - std::cout << "state set" << std::endl; - - - - // process tokens in batches of maxBatchSeqSize - auto tokensBatch = std::vector(); - tokensBatch.push_back(agents[i]->last_token); - for (size_t j = 0; j < tokens.size()-1; j++) { - tokensBatch.push_back(tokens[j]); - } - auto outputs = (*model)({tokensBatch}); - - agents[i]->last_token = tokens[tokens.size()-1]; - - - std::cout << "context processed" << std::endl; - - agents[i]->add_context_queue = ""; - std::cout << "context processed" << std::endl; - agents[i]->busy = false; - std::cout << "context processed busy" << std::endl; - - // std::cout << "context processed" << std::endl; - model->get_state(agents[i]->state, 0); - std::cout << "agent state retrieved" << std::endl; - } - - if (agents[i]->max_queued_tokens > 0) { - toProcess.push_back(agents[i]); - agents[i]->busy = true; - } + auto pool = get_threadpool(); + // sleep + // std::cout << "looping\n"; + // do context processing + std::vector toProcess = {}; + for (size_t i = 0; i < agents.size(); i++) { + if (agents[i]->add_context_queue != "") { + std::cout << "processing context" << std::endl; + auto tokens = tokenizer->encode(agents[i]->add_context_queue); + std::cout << "tokens: " << tokens.size() << std::endl; + model->set_state(agents[i]->state, 0); + agents[i]->state_index = 0; + std::cout << "state set" << std::endl; + + // process tokens in batches of maxBatchSeqSize + auto tokensBatch = std::vector(); + tokensBatch.push_back(agents[i]->last_token); + for (size_t j = 0; j < tokens.size()-1; j++) { + tokensBatch.push_back(tokens[j]); } + agents[i]->last_token = tokens[tokens.size()-1]; + std::cout << "starting work\n"; + auto outputs = (*model)({ tokensBatch }); + + pool->sync(); + + auto modelpointer = model; + auto ag = agents[i]; + pool->add_job( + [ ag, modelpointer] { + std::cout << "finished processing chunk:\n"; + std::cout << "starting sample\n"; + std::cout << "Clearing context queue\n"; + + ag->add_context_queue.clear(); + // agents[i]//->call_deferred("clearContext"); + std::cout << "Finsihed all\n"; + ag->busy = false; + modelpointer->get_state(ag->state,0); + }, + 0); + } - std::vector> tokens = {}; - - for (size_t i = 0; i < toProcess.size(); i++) { - tokens.push_back({toProcess[i]->last_token}); - model->set_state(toProcess[i]->state, i); - } + if (agents[i]->max_queued_tokens > 0) { + toProcess.push_back(agents[i]); + agents[i]->busy = true; + } + } - if (tokens.size() == 0) { - return; - } + std::vector> tokens = {}; - std::cout << "tokens: " << tokens.size() << std::endl; + for (size_t i = 0; i < toProcess.size(); i++) { + tokens.push_back({ toProcess[i]->last_token }); + model->set_state(toProcess[i]->state, i); + toProcess[i]->state_index = i; + } - auto outputs = (*model)(tokens); - // outputs.reshape({outputs.shape[0], size_t(pow(2, 16))}); - std::cout << "outputs: " << outputs.shape[0] << ":" << outputs.shape[1] << ":" << outputs.shape[2] << std::endl; - - for (size_t i = 0; i < toProcess.size(); i++) { - auto out = outputs[i]; - auto token = typical(flp(out.data), toProcess[i]->temperature, toProcess[i]->tau); - - toProcess[i]->max_queued_tokens -= 1; - - - // check if stop sequence - bool stopped = (token == 0); - if ((toProcess[i]->context.size() > 5) && token != 0){ - std::cout << "last token: " << token << std::endl; - auto tokstocheck = std::vector(toProcess[i]->context.end()-5, toProcess[i]->context.end()); - tokstocheck.push_back(token); - std::string context5toks = tokenizer->decode(tokstocheck); - for (size_t j = 0; j < toProcess[i]->stop_sequences.size(); j++) { - auto stop_sequence = toProcess[i]->stop_sequences[j]; - - if (context5toks.find(stop_sequence) != std::string::npos) { - toProcess[i]->max_queued_tokens = 0; - stopped = true; - } - } - } - - if (!stopped) { - model->get_state(toProcess[i]->state, i); - toProcess[i]->last_token = token; - toProcess[i]->context.push_back(token); - } - - if (toProcess[i]->max_queued_tokens == 0) { - toProcess[i]->busy = false; - } - - std::cout << "token: " << i << " processed" << std::endl; - } + std::cout << "tokens: " << tokens.size() << "\r"; + + if (tokens.size() != 0) { + auto outputs = (*model)(tokens); + + // pool->add_job([&] { + // process_output(outputs, toProcess); + // }, + // 0); + for (size_t i = 0; i < agents.size(); i++) { + auto modelpointer = model; + auto tpi = toProcess[i]; + auto data = outputs[i].data; + pool->add_job( + [tpi, modelpointer, data, i]() { + tpi->sample_output(flp(data), modelpointer); + }, + i); } - - }; + pool->sync(); + } - Variant createAgent() { + // outputs.reshape({outputs.shape[0], size_t(pow(2, 16))}); + + pool->add_job([&] { + // std::cout << "secondloop\n"; + listen(); + }, + 0); + } + Variant createAgent() { Agent *agent = new Agent(model, tokenizer); agents.push_back(agent); return Variant(agent); - - - - } - - protected: + +protected: static void _bind_methods() { - ClassDB::bind_method(D_METHOD("listen"), &GodotRWKV::listen); + // ClassDB::bind_method(D_METHOD("start"), &GodotRWKV::start); ClassDB::bind_method(D_METHOD("loadModel", "Path", "Threads"), &GodotRWKV::loadModel, DEFVAL(4)); ClassDB::bind_method(D_METHOD("loadTokenizer", "Path"), &GodotRWKV::loadTokenizer); ClassDB::bind_method(D_METHOD("createAgent"), &GodotRWKV::createAgent); @@ -274,5 +286,4 @@ class GodotRWKV : public Resource { // uninclude - #endif // RWKV_GODOT_H diff --git a/rwkv.cuh b/rwkv.cuh index 6a53d38..84df2c9 160000 --- a/rwkv.cuh +++ b/rwkv.cuh @@ -1 +1 @@ -Subproject commit 6a53d38cb65e19e57a26bee5e52c5f0134f4de39 +Subproject commit 84df2c9d376aaa6e4b89fc90df11feb629d21c7d