Skip to content

Commit 833730a

Browse files
committed
refactor(scraper): make code more readable
1 parent b157107 commit 833730a

File tree

3 files changed

+63
-50
lines changed

3 files changed

+63
-50
lines changed

npiai/tools/scrapers/base.py

+54-43
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from .types import (
2626
Column,
2727
SourceItem,
28-
SummaryItem,
29-
SummaryChunk,
28+
Row,
29+
RowBatch,
3030
)
3131

3232
__INDEX_COLUMN__ = Column(
@@ -44,10 +44,12 @@ class BaseScraper(FunctionTool, ABC):
4444
infer_prompt: str = DEFAULT_COLUMN_INFERENCE_PROMPT
4545

4646
@abstractmethod
47-
async def init_data(self, ctx: Context): ...
47+
async def init_data(self, ctx: Context):
48+
...
4849

4950
@abstractmethod
50-
async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: ...
51+
async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None:
52+
...
5153

5254
async def summarize_stream(
5355
self,
@@ -56,101 +58,110 @@ async def summarize_stream(
5658
batch_size: int = 1,
5759
limit: int = -1,
5860
concurrency: int = 1,
59-
) -> AsyncGenerator[SummaryChunk, None]:
61+
row_offset: int = 0,
62+
) -> AsyncGenerator[RowBatch, None]:
6063
"""
6164
Summarize the content of a webpage into a csv table represented as a stream of item objects.
6265
6366
Args:
67+
row_offset: row offset of the first batch in the entire task
6468
ctx: NPi context.
6569
output_columns: The columns of the output table. If not provided, use the `infer_columns` function to infer the columns.
66-
batch_size: The number of items to summarize in each batch. Default is 1.
67-
limit: The maximum number of items to summarize. If -1, all items are summarized.
70+
batch_size: The number of rows to summarize in each batch. Default is 1.
71+
limit: The maximum number of rows to summarize. If -1, all rows are summarized.
6872
concurrency: The number of concurrent tasks to run. Default is 1.
6973
7074
Returns:
71-
A stream of items. Each item is a dictionary with keys corresponding to the column names and values corresponding to the column values.
75+
A stream of rows. Each item is a dictionary with keys corresponding to the column names and values corresponding to the column values.
7276
"""
7377
if limit == 0:
7478
return
7579

7680
await self.init_data(ctx)
7781

78-
# total items summarized
79-
count = 0
80-
# remaining items to summarize, excluding the items being summarized
81-
remaining = limit
82-
# batch index
83-
batch_index = 0
82+
total_row_summarized = 0
83+
# remaining rows to summarize, excluding the rows being summarized
84+
remaining_rows = limit
85+
batch_no = 0
8486

8587
lock = asyncio.Lock()
8688

87-
no_count_index = 0
89+
row_number_count = 0
8890

89-
async def run_batch(results_queue: asyncio.Queue[SummaryChunk]):
90-
nonlocal count, no_count_index, remaining, batch_index
91+
# TODO
92+
# 1. one task to retrieve HTML items
93+
# 2. one task to summarize HTML items
9194

92-
if limit != -1 and remaining <= 0:
95+
async def run_batch(results_queue: asyncio.Queue[RowBatch]):
96+
nonlocal total_row_summarized, row_number_count, remaining_rows, batch_no
97+
98+
if limit != -1 and remaining_rows <= 0:
9399
return
94100

95101
async with lock:
96-
current_index = batch_index
97-
batch_index += 1
102+
current_batch = batch_no
103+
batch_no += 1
98104

99-
# calculate the number of items to summarize in the current batch
105+
# calculate the number of rows to summarize in the current batch
100106
requested_count = (
101-
min(batch_size, remaining) if limit != -1 else batch_size
107+
min(batch_size, remaining_rows) if limit != -1 else batch_size
102108
)
103-
# reduce the remaining count by the number of items in the current batch
109+
# reduce the remaining count by the number of rows in the current batch
104110
# so that the other tasks will not exceed the limit
105-
remaining -= requested_count
111+
remaining_rows -= requested_count
106112

107113
data = await self.next_items(ctx=ctx, count=requested_count)
108114

109115
if not data:
110-
await ctx.send_debug_message(f"[{self.name}] No more items found")
116+
await ctx.send_debug_message(f"[{self.name}] No more rows found")
111117
return
112118

113119
# await ctx.send_debug_message(
114120
# f"[{self.name}] Parsed markdown: {parsed_result.markdown}"
115121
# )
116122

117123
async with lock:
118-
no_index = no_count_index
119-
no_count_index += len(data)
124+
current_batch_row_number_offset = row_number_count
125+
row_number_count += len(data)
120126

121-
items = await self._summarize_llm_call(
127+
rows = await self._summarize_llm_call(
122128
ctx=ctx,
123129
items=data,
124130
output_columns=output_columns,
125131
)
126132

127-
await ctx.send_debug_message(f"[{self.name}] Summarized {len(items)} items")
133+
await ctx.send_debug_message(f"[{self.name}] Summarized {len(rows)} rows")
128134
#
129-
# if not items:
130-
# await ctx.send_debug_message(f"[{self.name}] No items summarized")
135+
# if not rows:
136+
# await ctx.send_debug_message(f"[{self.name}] No rows summarized")
131137
# return
132138

133139
async with lock:
134-
items_slice = items[:requested_count] if limit != -1 else items
140+
items_slice = rows[:requested_count] if limit != -1 else rows
135141
summarized_count = len(items_slice)
136-
count += summarized_count
137-
# recalculate the remaining count in case summary returned fewer items than requested
142+
total_row_summarized += summarized_count
143+
# recalculate the remaining count in case summary returned fewer rows than requested
138144
if summarized_count < requested_count:
139-
remaining += requested_count - summarized_count
145+
remaining_rows += requested_count - summarized_count
146+
147+
count = 1
148+
for row in items_slice:
149+
row["row_no"] = current_batch_row_number_offset + row_offset + count
150+
count += 1
140151

141152
await results_queue.put(
142-
SummaryChunk(
143-
index=no_index,
144-
batch_id=current_index,
153+
RowBatch(
154+
offset=current_batch_row_number_offset + row_offset,
155+
batch_id=current_batch,
145156
items=items_slice,
146157
)
147158
)
148159

149160
await ctx.send_debug_message(
150-
f"[{self.name}] Summarized {count} items in total"
161+
f"[{self.name}] Summarized {total_row_summarized} rows in total"
151162
)
152163

153-
if limit == -1 or remaining > 0:
164+
if limit == -1 or remaining_rows > 0:
154165
await run_batch(results_queue)
155166

156167
async for chunk in concurrent_task_runner(run_batch, concurrency):
@@ -268,7 +279,7 @@ async def _summarize_llm_call(
268279
ctx: Context,
269280
items: List[SourceItem],
270281
output_columns: List[Column],
271-
) -> List[SummaryItem]:
282+
) -> List[Row]:
272283
"""
273284
Summarize the content of a webpage into a table using LLM.
274285
@@ -309,9 +320,9 @@ async def _summarize_llm_call(
309320
async for row in llm_summarize(ctx.llm, messages):
310321
index = int(row.pop(__INDEX_COLUMN__["name"]))
311322
results.append(
312-
SummaryItem(
323+
Row(
313324
hash=items[index]["hash"],
314-
index=index,
325+
original_data_index=index,
315326
values=row,
316327
)
317328
)

npiai/tools/scrapers/types.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ class SourceItem(TypedDict):
1616
data: Any
1717

1818

19-
class SummaryItem(TypedDict):
19+
class Row(TypedDict):
2020
hash: str
21-
index: int
21+
original_data_index: int
22+
row_no: int
2223
values: Dict[str, str]
2324

2425

25-
class SummaryChunk(TypedDict):
26-
index: int
26+
class RowBatch(TypedDict):
27+
# row offset of this batch in the entire task
28+
offset: int
2729
batch_id: int
28-
items: List[SummaryItem]
30+
items: List[Row]

npiai/tools/scrapers/web/presets/linkedin/posts_scraper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from npiai.core import PlaywrightContext
1111
from npiai.error import UnauthorizedError
1212
from npiai.tools.scrapers.web import WebScraper
13-
from npiai.tools.scrapers import SummaryChunk, Column
13+
from npiai.tools.scrapers import RowBatch, Column
1414
from npiai.utils.html_to_markdown import CompactMarkdownConverter
1515
from .columns import POST_COLUMNS
1616

@@ -98,7 +98,7 @@ async def summarize_stream(
9898
batch_size: int = 1,
9999
limit: int = -1,
100100
concurrency: int = 1,
101-
) -> AsyncGenerator[SummaryChunk, None]:
101+
) -> AsyncGenerator[RowBatch, None]:
102102
stream = super().summarize_stream(
103103
ctx=ctx,
104104
output_columns=output_columns or POST_COLUMNS,

0 commit comments

Comments
 (0)