Skip to content

Commit

Permalink
notation changes
Browse files Browse the repository at this point in the history
  • Loading branch information
krypticmouse committed May 20, 2024
1 parent 1da68fa commit 1d46623
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
24 changes: 12 additions & 12 deletions pirate/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def save(self, path: str):
ext = path.split(".")[-1]

if ext == "json" or ext == "jsonl":
self.to_json(path)
self._to_json(path)
elif ext == "csv":
self.to_csv(path)
self._to_csv(path)
else:
raise NotImplementedError(f"Extension {ext} not supported")

Expand All @@ -65,22 +65,22 @@ def load(self, data: Union[str, List, Mapping]):
ext = data.split(".")[-1]

if ext == "json" or ext == "jsonl":
self.data = self.from_json(data)
self.data = self._from_json(data)
elif ext == "csv":
self.data = self.from_csv(data)
self.data = self._from_csv(data)
else:
raise NotImplementedError(f"Extension {ext} not supported")

elif isinstance(data, list):
self.data = self.from_list(data)
self.data = self._from_list(data)

elif isinstance(data, dict):
self.data = self.from_dict(data)
self.data = self._from_dict(data)

else:
raise NotImplementedError(f"Type {type(data)} not supported")

def to_json(self, path: str):
def _to_json(self, path: str):
"""
Save the data to a JSON file.
Expand All @@ -91,7 +91,7 @@ def to_json(self, path: str):
for k, v in self.data.items():
f.write(json.dumps({self.id_key: k, self.content_key: v}) + "\n")

def to_csv(self, path: str):
def _to_csv(self, path: str):
"""
Save the data to a CSV file.
Expand All @@ -102,7 +102,7 @@ def to_csv(self, path: str):
for k, v in self.data.items():
f.write(f"{k},{v}\n")

def from_dict(self, data: Mapping) -> Mapping:
def _from_dict(self, data: Mapping) -> Mapping:
"""
Load data from a dictionary.
Expand All @@ -111,7 +111,7 @@ def from_dict(self, data: Mapping) -> Mapping:
"""
return data

def from_list(self, data: List) -> Mapping:
def _from_list(self, data: List) -> Mapping:
"""
Load data from a list.
Expand All @@ -128,7 +128,7 @@ def from_list(self, data: List) -> Mapping:

return mapped_data

def from_json(self, data: str):
def _from_json(self, data: str):
"""
Load data from a JSON file.
Expand All @@ -145,7 +145,7 @@ def from_json(self, data: str):

return mapped_data

def from_csv(self, data: str):
def _from_csv(self, data: str):
"""
Load data from a CSV file.
Expand Down
44 changes: 22 additions & 22 deletions pirate/data/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def load(self, ranking: Union[str, List]):
ext = ranking.split(".")[-1]

if ext == "json" or ext == "jsonl":
self.ranking = self.from_json(ranking)
self.ranking = self._from_json(ranking)
elif ext == "csv":
self.ranking = self.from_csv(ranking)
self.ranking = self._from_csv(ranking)
else:
raise NotImplementedError(f"Extension {ext} not supported")

elif isinstance(ranking, list):
self.from_list(ranking)
self._from_list(ranking)
else:
raise NotImplementedError(f"Type {type(ranking)} not supported")

Expand All @@ -54,13 +54,25 @@ def save(self, path: str):
ext = path.split(".")[-1]

if ext == "json" or ext == "jsonl":
self.to_json(path)
self._to_json(path)
elif ext == "csv":
self.to_csv(path)
self._to_csv(path)
else:
raise NotImplementedError(f"Extension {ext} not supported")

def get_passage_groups(self, qid: str) -> pl.DataFrame:
"""
Get the passage groups for a given query ID.
Args:
qid: The query ID for which the passage groups will be retrieved.
Returns:
A DataFrame with the passage groups for the given query ID.
"""
return self.data.filter(pl.col("qid") == qid).sort("rank")

def from_json(self, path: str):
def _from_json(self, path: str):
"""
Load ranking from a JSON file.
Expand All @@ -69,7 +81,7 @@ def from_json(self, path: str):
"""
self.data = pl.read_ndjson(path, schema={"qid": pl.String, "pid": pl.String, "rank": pl.Int32, "score": pl.Float64})

def from_csv(self, path: str):
def _from_csv(self, path: str):
"""
Load ranking from a CSV file.
Expand All @@ -78,19 +90,7 @@ def from_csv(self, path: str):
"""
self.data = pl.read_csv(path, columns=["qid", "pid", "rank", "score"])

def get_passage_groups(self, qid: str) -> pl.DataFrame:
"""
Get the passage groups for a given query ID.
Args:
qid: The query ID for which the passage groups will be retrieved.
Returns:
A DataFrame with the passage groups for the given query ID.
"""
return self.data.filter(pl.col("qid") == qid).sort("rank")

def from_list(self, ranking: List):
def _from_list(self, ranking: List):
"""
Load ranking from a list.
Expand All @@ -99,7 +99,7 @@ def from_list(self, ranking: List):
"""
self.data = pl.DataFrame(ranking, schema=["qid", "pid", "rank", "score"])

def to_json(self, path: str):
def _to_json(self, path: str):
"""
Save the ranking to a JSON file.
Expand All @@ -108,7 +108,7 @@ def to_json(self, path: str):
"""
self.data.write_ndjson(path)

def to_csv(self, path: str):
def _to_csv(self, path: str):
"""
Save the ranking to a CSV file.
Expand Down
16 changes: 8 additions & 8 deletions pirate/data/triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def load(self, triples: Union[str, List[List[str]]]) -> List[List[str]]:
ext = triples.split(".")[-1]

if ext == "json" or ext == "jsonl":
return self.from_json(triples)
return self._from_json(triples)
elif ext == "csv":
return self.from_csv(triples)
return self._from_csv(triples)
else:
raise NotImplementedError(f"Extension {ext} not supported")

Expand All @@ -54,13 +54,13 @@ def save(self, path: str):
ext = path.split(".")[-1]

if ext == "json" or ext == "jsonl":
self.to_json(path)
self._to_json(path)
elif ext == "csv":
self.to_csv(path)
self._to_csv(path)
else:
raise NotImplementedError(f"Extension {ext} not supported")

def from_json(self, path: str) -> List[List[str]]:
def _from_json(self, path: str) -> List[List[str]]:
"""
Load triples from a JSON file.
Expand All @@ -73,7 +73,7 @@ def from_json(self, path: str) -> List[List[str]]:
with open(path, "r") as f:
return [json.loads(line) for line in f]

def from_csv(self, path: str) -> List[List[str]]:
def _from_csv(self, path: str) -> List[List[str]]:
"""
Load triples from a CSV file.
Expand All @@ -86,7 +86,7 @@ def from_csv(self, path: str) -> List[List[str]]:
with open(path, "r") as f:
return [[item.strip() for item in line.split(",")] for line in f]

def to_json(self, path: str) -> None:
def _to_json(self, path: str) -> None:
"""
Save the triples to a JSON file.
Expand All @@ -97,7 +97,7 @@ def to_json(self, path: str) -> None:
for triple in self.triples:
f.write(json.dumps(list(triple)) + "\n")

def to_csv(self, path: str) -> None:
def _to_csv(self, path: str) -> None:
"""
Save the triples to a CSV file.
Expand Down

0 comments on commit 1d46623

Please sign in to comment.