
# RWKV Repository Structure

## src/rwkv/__init__.py

```python
# RWKV Language Model Implementation
# https://github.com/BlinkDL/RWKV-LM

import numpy as np
import torch
from typing import List
```

## src/rwkv/model.py

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class RWKV_RNN(nn.Module):
    """RWKV in RNN mode: consumes one token at a time, carrying a recurrent state."""
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.n_embd = args.n_embd
        self.n_layer = args.n_layer
        self.eval()  # inference-only module: set eval mode at construction
```
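
In RNN mode the model processes one token per call and threads a fixed-size recurrent state between calls, so per-step memory stays constant regardless of context length. A minimal sketch of that calling pattern; the `forward(token, state)` method returning `(logits, state)` is an assumption, not shown in the excerpt above:

```python
import torch
from src.rwkv.model import RWKV_RNN
from src.utils.config import args

model = RWKV_RNN(args)
with torch.no_grad():
    state = None                                        # empty state before the first token
    for token_id in (12092, 1533):                      # example token ids
        logits, state = model.forward(token_id, state)  # assumed interface
# logits: next-token scores conditioned on every token fed in so far
```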

## src/utils/config.py

```python
import types

# Hyperparameters for the 168M-parameter Pile checkpoint
args = types.SimpleNamespace()
args.MODEL_NAME = '/models/RWKV-x070-Pile-168M'
args.n_layer = 12
args.n_embd = 768
args.vocab_size = 50304
args.head_size = 64  # 768 / 64 = 12 heads per layer
```

## src/utils/sampling.py

```python
import torch
from torch.nn import functional as F

def sample_logits(logits, temperature=1.0, top_p=1.0, top_k=0):
    probs = F.softmax(logits.float(), dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    if top_k > 0:  # keep only the top_k most probable tokens
        probs[sorted_ids[top_k:]] = 0.0
    # Nucleus (top-p) cutoff: keep the smallest prefix of tokens whose
    # cumulative probability reaches top_p (the top token always survives)
    probs[sorted_ids[torch.cumsum(sorted_probs, dim=-1) - sorted_probs >= top_p]] = 0.0
    probs = probs ** (1.0 / temperature)  # temperature reshaping
    return int(torch.multinomial(probs / probs.sum(), num_samples=1))
```
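
A quick usage example, assuming the module lives at `src/utils/sampling.py` as the layout above suggests (the logits tensor here is a random stand-in for real model output):

```python
import torch
from src.utils.sampling import sample_logits

logits = torch.randn(50304)  # stand-in for one step of model output
token_id = sample_logits(logits, temperature=1.0, top_p=0.85)
```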

## src/data/tokenizer.py

```python
from tokenizers import Tokenizer

def load_tokenizer(path):
    # Load a HuggingFace `tokenizers` tokenizer from a JSON file on disk
    return Tokenizer.from_file(path)
```
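
Usage, assuming a HuggingFace-format tokenizer JSON on disk (the filename is illustrative):

```python
tok = load_tokenizer("20B_tokenizer.json")  # hypothetical path
ids = tok.encode("Hello world").ids          # list of token ids
text = tok.decode(ids)                       # back to a string
```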

## scripts/train.py

```python
from src.rwkv.model import RWKV_RNN
from src.utils.config import args

def main():
    model = RWKV_RNN(args)
    # Training logic here

if __name__ == "__main__":
    main()
```

## scripts/evaluate.py

```python
import json
import math
from src.rwkv.model import RWKV_RNN

def evaluate_lambada():
    # LAMBADA last-word prediction: each JSONL line holds one test passage
    with open("misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
        todo = [json.loads(line) for line in f]
```
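
The excerpt above only loads the data. A sketch of how the rest of the benchmark might proceed; the `doc["text"]` record schema, the `forward(token, state)` interface, and the tokenizer object are all assumptions carried over from the sketches above:

```python
import math
import torch
import torch.nn.functional as F

def lambada_score(model, tok, todo):
    nll, n_tokens, n_correct = 0.0, 0, 0
    for doc in todo:
        ctx, last = doc["text"].rsplit(" ", 1)       # target = final word
        state, logits = None, None
        for t in tok.encode(ctx).ids:                # prime the RNN state on the context
            logits, state = model.forward(t, state)  # assumed interface
        hit = True
        for t in tok.encode(" " + last).ids:         # score each token of the final word
            logp = F.log_softmax(logits.float(), dim=-1)
            nll -= logp[t].item()
            n_tokens += 1
            hit = hit and int(torch.argmax(logits)) == t
            logits, state = model.forward(t, state)
        n_correct += int(hit)
    # per-token perplexity on targets, exact-match accuracy per passage
    return math.exp(nll / n_tokens), n_correct / len(todo)
```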

## tests/test_model.py

```python
import torch
import unittest
from src.rwkv.model import RWKV_RNN

class TestRWKV(unittest.TestCase):
    def test_forward_pass(self):
        # Test implementation
        pass

if __name__ == "__main__":
    unittest.main()
```

## requirements.txt

```text
torch>=2.0.0
numpy>=1.20.0
tokenizers>=0.13.0
```

## README.md

```markdown
# RWKV Language Model Implementation

An implementation of the RWKV language-model architecture, focused on efficient inference in RNN mode.

## Setup
1. Install dependencies: `pip install -r requirements.txt`
2. Download model weights (e.g. RWKV-x070-Pile-168M, per `src/utils/config.py`)
3. Run the LAMBADA evaluation: `python scripts/evaluate.py`
```
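
Putting the pieces together, a minimal generation loop might look like the following. The `forward(token, state)` interface and the tokenizer path are assumptions carried over from the sketches above, not part of the excerpts:

```python
import torch
from src.rwkv.model import RWKV_RNN
from src.utils.config import args
from src.utils.sampling import sample_logits
from src.data.tokenizer import load_tokenizer

tok = load_tokenizer("20B_tokenizer.json")  # hypothetical path
model = RWKV_RNN(args)

with torch.no_grad():
    state, logits = None, None
    for t in tok.encode("The capital of France is").ids:  # prime on the prompt
        logits, state = model.forward(t, state)           # assumed interface
    out_ids = []
    for _ in range(16):                                   # sample 16 continuation tokens
        next_id = sample_logits(logits, temperature=1.0, top_p=0.85)
        out_ids.append(next_id)
        logits, state = model.forward(next_id, state)

print(tok.decode(out_ids))
```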