forked from vccheng2001/daily-dose-of-news
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
88 lines (73 loc) · 2.49 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
import replicate
from flask import (
Flask,
jsonify,
render_template,
send_from_directory,
request,
)
import random
from news_api import NewsAPIClient
import os
# Token handed to the news API client; None when the variable is not set.
NEWS_API_TOKEN = os.environ.get('NEWS_API_TOKEN')

# Flask application object that the route decorators in this file attach to.
app = Flask(__name__)
# Render index page
@app.route("/")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
# Predict
@app.route("/api/predict", methods=["POST"])
def predict():
    """Fetch a news headline and start an image prediction for it.

    Expects a JSON object body with ``category`` and ``country`` keys.
    ``country`` is either the literal ``"all"`` (no country filter) or a
    string whose first space-separated token is passed to the news client.

    Returns:
        On success, a JSON response with ``prediction_id``, ``headline``,
        ``src``, ``url`` and ``description``.
        A JSON ``error`` message with status 400 when the body is not a
        JSON object, a required key is missing, or ``country`` is not a
        string; status 404 when the news client returns no headline.
    """
    # silent=True returns None instead of raising on a missing or invalid
    # JSON body, so malformed requests get a JSON 400 rather than a crash.
    body = request.get_json(silent=True)
    if not isinstance(body, dict) or 'category' not in body or 'country' not in body:
        return jsonify({"error": "Request body must be a JSON object with 'category' and 'country'."}), 400

    category = body['category']
    country = body['country']
    if not isinstance(country, str):
        return jsonify({"error": "'country' must be a string."}), 400

    if country == "all":
        country = None  # all countries
    else:
        country = country.split(' ')[0]  # get country abbreviation

    # Get model
    print('Category: ', category, 'Country:', country)
    print('Fetching model and version......')
    model = replicate.models.get("mehdidc/feed_forward_vqgan_clip")
    version = model.versions.get(
        "28b5242dadb5503688e17738aaee48f5f7f5c0b6e56493d7cf55f74d02f144d8"
    )

    # Instantiate news API client
    print('Instantiating News API Client......')
    news_client = NewsAPIClient(NEWS_API_TOKEN=NEWS_API_TOKEN)
    print(f'Fetching news headlines for {category} category.......')
    result = news_client.get_headlines(category, country)

    # Without a headline there is no prompt to generate an image from, so
    # stop here instead of sending prompt=None to Replicate.
    if not result:
        return jsonify({"error": "No headline found for the requested category and country."}), 404

    headline, src, url, description = result
    print('Processing new headline......', headline)

    # Create Replicate prediction object
    prediction = replicate.predictions.create(
        version=version,
        input={
            "prompt": headline,
            "model": 'cc12m_32x1024_mlp_mixer_openclip_laion2b_ViTB32_256x256_v0.4.th',
            "prior": False,
            "grid": '1x1',
            # Fresh random seed per request so repeated headlines vary.
            "seed": random.randint(0, 2**15 - 1),
        },
    )

    return jsonify({
        "prediction_id": prediction.id,
        "headline": headline,
        "src": src,
        "url": url,
        "description": description,
    })
# Get prediction by its ID
@app.route("/api/predictions/<prediction_id>", methods=["GET"])
def get_prediction(prediction_id):
    """Return the current output and status of a Replicate prediction.

    Args:
        prediction_id: Identifier taken from the URL and passed to
            ``replicate.predictions.get``.

    Returns:
        A JSON response with the prediction's ``output`` and ``status``.
    """
    # Function-scope import: no module-level import of `time` is visible in
    # this file.
    import time

    prediction = replicate.predictions.get(prediction_id)
    if prediction.output:
        print('Prediction output', prediction.output)
        # NOTE(review): the source's indentation was lost; this 5-second
        # pause is assumed to apply only once output is present - confirm.
        time.sleep(5)

    return jsonify({"output": prediction.output, "status": prediction.status})
@app.route("/static/<path:path>")
def send_static(path):
    """Serve the file at ``path`` from the ``static`` directory."""
    static_dir = "static"
    return send_from_directory(static_dir, path)
if __name__ == "__main__":
    # Debug mode stays on by default for local runs, as before. Setting
    # FLASK_DEBUG to 0/false/no turns it off without a code change, which
    # matters because the debug server must not be exposed publicly.
    debug_enabled = os.getenv("FLASK_DEBUG", "1").lower() not in {"0", "false", "no"}
    app.run(debug=debug_enabled)