-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
107 lines (87 loc) · 3.67 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
import os
import torch
import nltk
import urllib.request
from models.model_builder import ExtSummarizer
from newspaper import Article
from ext_sum import summarize
from fpdf import FPDF
def main():
    """Render the text-summariser Streamlit page.

    Ensures the MobileBERT checkpoint exists locally, loads the model, and
    collects input either as raw text or as an article crawled from a URL.
    The input is written to 'input.txt' and summarised with ``summarize``
    into 'summary.txt'. The resulting summary is displayed on the page.
    """
    import html  # local import: used to escape untrusted text before HTML rendering

    st.markdown("<h1 style='text-align: center; color:Red;'>Deep Learning Based Text-Summariser</h1>", unsafe_allow_html=True)

    # Download the checkpoint (and NLTK data) only when it is not present yet.
    if not os.path.exists('checkpoints/mobilebert_ext.pt'):
        download_model()

    # load_model is wrapped in @st.cache, so reruns reuse the loaded model.
    model = load_model('mobilebert')

    # Input
    input_type = st.radio("Input Type: ", ["URL", "Raw Text"])
    st.markdown("<h3 style='text-align: center;'>Input</h3>", unsafe_allow_html=True)
    if input_type == "Raw Text":
        # input.txt is both the sample shown here and the file handed to
        # summarize() below, so the sample is whatever was last written there.
        with open("input.txt", encoding="utf-8") as f:
            sample_text = f.read()
        text = st.text_area("", sample_text, 200)
    else:
        url = st.text_input("", "https://www.cnn.com/2020/05/29/tech/facebook-violence-trump/index.html")
        st.markdown(f"[*Read Original News*]({url})")
        try:
            text = crawl_url(url)
        except Exception as err:  # UI boundary: report the failure instead of a traceback
            st.error(f"Could not fetch the article: {err}")
            return

    # Nothing to summarise (empty text area, or a page with no extractable
    # article). Stop before overwriting input.txt or running the model.
    if not text or not text.strip():
        st.warning("Please provide some text to summarise.")
        return

    input_fp = "input.txt"
    # Explicit UTF-8 so non-ASCII article text can always be written.
    # NOTE(review): summarize() reads this file itself; confirm it decodes UTF-8 too.
    with open(input_fp, 'w', encoding="utf-8") as file:
        file.write(text)

    # Summarize
    sum_level = st.radio("Output Length: ", ["Short", "Medium"])
    max_length = 3 if sum_level == "Short" else 5
    result_fp = 'summary.txt'
    summary = summarize(input_fp, result_fp, model, max_length=max_length)
    st.markdown("<h3 style='text-align: center;'>Summary</h3>", unsafe_allow_html=True)
    # The summary is derived from user-supplied or crawled web text, so escape
    # it before embedding it in raw HTML to prevent markup/script injection.
    st.markdown(f"<p align='justify'>{html.escape(str(summary))}</p>", unsafe_allow_html=True)

    # TODO: offer the summary (summary.txt) as a PDF download; an unfinished
    # FPDF-based draft of this feature previously lived here as commented-out code.
def download_model():
    """Download NLTK's 'popular' data and the MobileBERT extractive checkpoint.

    The checkpoint is streamed to 'checkpoints/mobilebert_ext.pt'. A Streamlit
    warning and progress bar report progress during the download. Both widgets
    are cleared when the function exits, whether or not the download succeeded.
    Network and file errors propagate to the caller.
    """
    nltk.download('popular')
    # NOTE(review): this URL embeds an API key in source control; consider
    # moving it to configuration / Streamlit secrets.
    url = 'https://www.googleapis.com/drive/v3/files/1umMOXoueo38zID_AKFSIOGxG9XjS5hDC?alt=media&key=AIzaSyCmo6sAQ37OK8DK4wnT94PoLx5lx-7VTDE'
    dest_fp = 'checkpoints/mobilebert_ext.pt'
    # Download into a side file and rename it only after it is complete.
    # Otherwise an interrupted download leaves a truncated checkpoint that
    # main()'s os.path.exists() check would accept on every later run.
    part_fp = dest_fp + '.part'
    os.makedirs(os.path.dirname(dest_fp), exist_ok=True)

    # These are handles to two visual elements to animate.
    weights_warning, progress_bar = None, None
    try:
        weights_warning = st.warning("Downloading checkpoint...")
        progress_bar = st.progress(0)
        with open(part_fp, 'wb') as output_file:
            with urllib.request.urlopen(url) as response:
                # Content-Length can be absent; treat that as "unknown size"
                # instead of crashing on int(None) or dividing by zero.
                length = int(response.info().get("Content-Length") or 0)
                counter = 0.0
                MEGABYTES = 2.0 ** 20.0
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)
                    # We perform animation by overwriting the elements.
                    if length > 0:
                        weights_warning.warning("Downloading checkpoint... (%6.2f/%6.2f MB)" %
                            (counter / MEGABYTES, length / MEGABYTES))
                        progress_bar.progress(min(counter / length, 1.0))
                    else:
                        weights_warning.warning("Downloading checkpoint... (%6.2f MB)" %
                            (counter / MEGABYTES))
        # Atomically publish the finished file under its final name.
        os.replace(part_fp, dest_fp)
    finally:
        # A leftover side file only exists if the download failed part-way.
        if os.path.exists(part_fp):
            os.remove(part_fp)
        # Finally, we remove these visual elements by calling .empty().
        if weights_warning is not None:
            weights_warning.empty()
        if progress_bar is not None:
            progress_bar.empty()
@st.cache(suppress_st_warning=True)
def load_model(model_type):
    """Build an ExtSummarizer for *model_type* from its local checkpoint.

    Loads 'checkpoints/<model_type>_ext.pt' onto the CPU and wraps it in an
    ExtSummarizer running on the CPU. The @st.cache decorator memoises the
    result per argument, so Streamlit reruns reuse the same model object.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt_path = f'checkpoints/{model_type}_ext.pt'
    state = torch.load(ckpt_path, map_location='cpu')
    return ExtSummarizer(device="cpu", checkpoint=state, bert_type=model_type)
def crawl_url(url):
    """Return the main article text of the web page at *url*.

    Uses newspaper's Article: the page is downloaded, then parsed, and the
    extracted body text is returned. Errors raised by newspaper propagate.
    """
    news_item = Article(url)
    # Download must happen before parsing.
    for step in (news_item.download, news_item.parse):
        step()
    return news_item.text
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()

# Footer. Streamlit lays out elements in call order, so this must be emitted
# after main(); emitting it before main() put it above the page title. It stays
# at module level so importing this module still renders it.
st.markdown("<h3 style='text-align: center; bottom:0px;'>Developed By SAHREEN HAIDER ✏️</h3>", unsafe_allow_html=True)