-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline.py
38 lines (31 loc) · 2.04 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import csv
import argparse
from utils_pipeline import load_api_credentials, initialize_medrag_model, process_json
def main():
# Set up command-line argument parsing
parser = argparse.ArgumentParser(description="Run MedRAG model on CSV data and generate answers.")
parser.add_argument("--input_json_path", default="../MIRAGE_alan/benchmark_partial_100_medqa.json",help="Path to the input CSV file.")
parser.add_argument("--output_folder_path", default="../MIRAGE_alan/prediction_amedrag/",help="Path to the input CSV file.")
parser.add_argument("--dataset_name", default="pubmedqa")
parser.add_argument("--model_name", default="openai/gpt-4-turbo", help="Name of the MedRAG model to use.")
parser.add_argument("--store_desc", default="amedrag/OpenAI", help="Name of the MedRAG model to use.")
parser.add_argument("--is_rag", action="store_true", help="Whether to use the RAG model.") # default leave it False
parser.add_argument("--is_agent", action="store_true", help="Whether to use the RAG-agent model.") # default leave it False
parser.add_argument("--is_iterative", action="store_true", help="Whether to use the iterative approach.") # default leave it False
parser.add_argument("--start_idx", default=0, type=int, help="Starting index for processing.")
parser.add_argument("--segment_size", default=160, type=int, help="Number of questions to process at a time.")
args = parser.parse_args()
print("start processing")
# Load API credentials
# load_api_credentials()
# Initialize the MedRAG model
# cot_gpt_4o_mini = initialize_medrag_model(args.model_name, args.is_rag, args.is_agent)
medRAG_model = initialize_medrag_model(args.model_name, args.is_rag, args.is_iterative, args.is_agent)
# Process CSV and generate answers
process_json(
args.input_json_path, args.output_folder_path,
medRAG_model, args.dataset_name, args.store_desc,
args.model_name, args.is_rag, args.is_iterative, args.start_idx, args.segment_size)
if __name__ == "__main__":
main()