# CS50_Labs/Lab9/finance/application.py
#
# Source-listing metadata: 331 lines, 10 KiB, Python.

import os
from cs50 import SQL
from flask import Flask, flash, redirect, render_template, request, session
from flask_session import Session
from tempfile import mkdtemp
from werkzeug.exceptions import default_exceptions, HTTPException, InternalServerError
from werkzeug.security import check_password_hash, generate_password_hash
from datetime import datetime
from helpers import apology, login_required, lookup, usd
# Create the Flask application object that every route below registers on.
app = Flask(__name__)

# Re-read templates from disk when they change instead of using cached copies.
app.config.update(TEMPLATES_AUTO_RELOAD=True)
@app.after_request
def after_request(response):
    """Mark every outgoing response as non-cacheable.

    Sets the ``Cache-Control``, ``Expires`` and ``Pragma`` headers so that
    neither browsers nor intermediaries reuse a stored copy of the page.

    Args:
        response: the response object Flask is about to send.

    Returns:
        The same response object, with the headers added.
    """
    no_cache_headers = (
        ("Cache-Control", "no-cache, no-store, must-revalidate"),
        ("Expires", 0),
        ("Pragma", "no-cache"),
    )
    for header_name, header_value in no_cache_headers:
        response.headers[header_name] = header_value
    return response
# Custom Jinja filter so templates can format money with the usd helper.
app.jinja_env.filters["usd"] = usd

# Keep sessions on the server's filesystem (instead of in signed cookies),
# inside a freshly created temporary directory.
app.config.update(
    SESSION_FILE_DIR=mkdtemp(),
    SESSION_PERMANENT=False,
    SESSION_TYPE="filesystem",
)
Session(app)

# CS50 Library handle for the SQLite database.
db = SQL("sqlite:///finance.db")

# The stock-quote helper presumably depends on API_KEY; fail fast at import
# time if it is missing rather than erroring on the first lookup.
if not os.environ.get("API_KEY"):
    raise RuntimeError("API_KEY not set")
@app.route("/")
@login_required
def index():
    """Show portfolio of stocks.

    For every row of the user's ``portfolio`` table, fetches a live quote and
    adds ``name``, formatted ``price`` and formatted position ``total`` to the
    row. Also renders the user's cash and the grand total (cash plus the
    value of all positions), all formatted with ``usd``.
    """
    user_id = session["user_id"]
    holdings = db.execute("SELECT * FROM portfolio WHERE id = ?", user_id)
    cash = db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"]

    # Start from the cash balance and add each position's market value.
    running_total = cash
    for holding in holdings:
        quote_data = lookup(holding["symbol"])
        unit_price = float(quote_data["price"])
        company = quote_data["name"]
        position_value = unit_price * int(holding["amount"])
        holding.update(
            name=company,
            price=usd(unit_price),
            total=usd(position_value),
        )
        running_total += position_value

    return render_template(
        "index.html",
        portfolio=holdings,
        cash=usd(cash),
        grand_total=usd(running_total),
    )
@app.route("/buy", methods=["GET", "POST"])
@login_required
def buy():
    """Buy shares of stock.

    GET renders the purchase form. POST validates the ``symbol`` and
    ``shares`` form fields, looks up the current quote, checks that the
    user's cash covers the cost, then deducts the cost, records the purchase
    in ``history`` and adds the shares to the user's ``portfolio`` row
    (creating the row if the user did not hold this symbol yet).
    """
    if request.method != "POST":
        return render_template("buy.html")

    user_id = session["user_id"]
    # request.form.get() returns None for an absent field; fold that into ""
    # so one emptiness check covers missing, empty and whitespace-only input.
    symbol = (request.form.get("symbol") or "").strip()
    shares = (request.form.get("shares") or "").strip()

    if not symbol:
        return apology("Missing symbol")
    if not shares:
        return apology("Missing number of shares")
    # isdecimal() (unlike isdigit(), which also accepts e.g. "²") only admits
    # characters int() can parse, so the int() calls below cannot raise.
    if not shares.isdecimal() or int(shares) < 1:
        return apology("Number of shares must be at least 1.")

    # Get the stock's current quote
    stock = lookup(symbol)
    if not stock:
        return apology("Not existing symbol")
    shares = int(shares)
    symbol = symbol.upper()

    # Compare the total cost against the user's available cash
    total_price = stock["price"] * shares
    funds = float(db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"])
    if funds < total_price:
        return apology("You don't have enough money")

    # Deduct the cost and log the transaction.
    # NOTE: the history "price" column receives the total cost of the order,
    # not the per-share price.
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    db.execute("UPDATE users SET cash = cash - ? WHERE id = ?", total_price, user_id)
    db.execute(
        "INSERT INTO history (id, symbol, shares, price, transacted) VALUES (?, ?, ?, ?, ?)",
        user_id, symbol, shares, total_price, now,
    )

    # Add to an existing holding, or create one for a newly bought symbol
    owned = db.execute("SELECT * FROM portfolio WHERE id = ? AND symbol = ?", user_id, symbol)
    if owned:
        db.execute(
            "UPDATE portfolio SET amount = amount + ? WHERE id = ? AND symbol = ?",
            shares, user_id, symbol,
        )
    else:
        db.execute(
            "INSERT INTO portfolio (id, symbol, amount) VALUES (?, ?, ?)",
            user_id, symbol, shares,
        )

    return redirect("/")
@app.route("/history")
@login_required
def history():
    """Show history of transactions.

    Renders every ``history`` row recorded for the logged-in user.
    """
    transactions = db.execute(
        "SELECT * FROM history WHERE id = ?", session["user_id"]
    )
    return render_template("history.html", history=transactions)
@app.route("/login", methods=["GET", "POST"])
def login():
    """Log user in.

    Always clears the current session first. GET renders the login form.
    POST checks the submitted credentials against the ``users`` table and,
    on success, stores the user's id in the session and redirects to "/".
    Every validation failure returns a 403 apology.
    """
    # Drop any existing login before (re)authenticating
    session.clear()

    if request.method != "POST":
        # Reached via GET (clicking a link or following a redirect)
        return render_template("login.html")

    username = request.form.get("username")
    password = request.form.get("password")
    if not username:
        return apology("must provide username", 403)
    if not password:
        return apology("must provide password", 403)

    # Exactly one matching user whose stored hash accepts the password
    matches = db.execute("SELECT * FROM users WHERE username = ?", username)
    if len(matches) != 1 or not check_password_hash(matches[0]["hash"], password):
        return apology("invalid username and/or password", 403)

    session["user_id"] = matches[0]["id"]
    return redirect("/")
@app.route("/logout")
def logout():
    """Log user out by wiping the session, then redirect to "/"."""
    # Clearing the whole session also removes the stored user_id
    session.clear()
    home = "/"
    return redirect(home)
@app.route("/quote", methods=["GET", "POST"])
@login_required
def quote():
    """Get stock quote.

    GET renders the lookup form. POST looks up the submitted ``symbol`` and
    renders its symbol, company name and ``usd``-formatted price, or an
    apology if the symbol is missing or unknown.
    """
    if request.method != "POST":
        return render_template("quote.html")

    # A missing field yields None; treat it like "" and whitespace-only input
    # instead of handing None to lookup().
    symbol = (request.form.get("symbol") or "").strip()
    if not symbol:
        return apology("Missing symbol")

    stock = lookup(symbol)
    if not stock:
        return apology("Not existing symbol")

    return render_template(
        "quoted.html",
        symbol=stock["symbol"],
        name=stock["name"],
        price=usd(stock["price"]),
    )
@app.route("/register", methods=["GET", "POST"])
def register():
    """Register user.

    GET renders the sign-up form. POST requires a username, password and a
    matching confirmation. The password must contain at least one uppercase
    character, one digit and one "symbol" (any character that is neither
    uppercase, a digit, nor lowercase). Taken usernames are rejected. On
    success it stores ``generate_password_hash(password)`` in ``users``,
    logs the new user in and redirects to "/".
    """
    if request.method != "POST":
        return render_template("register.html")

    name = request.form.get("username")
    password = request.form.get("password")
    confirmation = request.form.get("confirmation")

    # `not x` also catches None (field absent from the form). A bare `== ""`
    # check let None through, and `for char in None` then raised TypeError.
    if not name:
        return apology("Missing username")
    if not password:
        return apology("Missing password")
    if not confirmation:
        return apology("Missing password confirmation")
    if password != confirmation:
        return apology("Passwords not matching")

    # Password strength check: each character is classified into the first
    # matching category (uppercase, then digit, then non-lowercase "symbol").
    has_upper = has_digit = has_symbol = False
    for char in password:
        if char.isupper():
            has_upper = True
        elif char.isdigit():
            has_digit = True
        elif not char.islower():
            has_symbol = True
    if not (has_upper and has_digit and has_symbol):
        return apology("Password must contain 1 uppercase, 1 number and 1 symbol")

    # Reject usernames that are already taken
    if db.execute("SELECT * FROM users WHERE username = ?", name):
        return apology("Username is already used")

    db.execute(
        "INSERT INTO users (username, hash) VALUES (?, ?)",
        name, generate_password_hash(password),
    )

    # Log the new user in and send them to the homepage
    rows = db.execute("SELECT * FROM users WHERE username = ?", name)
    session["user_id"] = rows[0]["id"]
    return redirect("/")
@app.route("/sell", methods=["GET", "POST"])
@login_required
def sell():
    """Sell shares of stock.

    GET renders the sell form with the user's ``portfolio`` rows. POST
    validates the chosen symbol and share count against what the user owns,
    credits the current market value to the user's cash, logs the sale in
    ``history`` with a negative share count, and then either lowers the
    portfolio amount or deletes the row when the whole position is sold.
    Invalid input produces an apology instead of a server error.
    """
    user_id = session["user_id"]

    if request.method != "POST":
        portfolio = db.execute("SELECT * FROM portfolio WHERE id = ?", user_id)
        return render_template("sell.html", portfolio=portfolio)

    # Normalise absent fields (None) to "" so the checks below cover them.
    # Symbols are stored upper-cased by buy(), so match that here.
    symbol = (request.form.get("symbol") or "").strip().upper()
    shares = (request.form.get("shares") or "").strip()
    if not symbol:
        return apology("Missing symbol")
    # Validate before converting: isdecimal() guarantees int() cannot raise.
    if not shares.isdecimal() or int(shares) < 1:
        return apology("Invalid number of shares")
    shares = int(shares)

    # The user must hold this symbol, and at least `shares` of it
    owned = db.execute(
        "SELECT amount FROM portfolio WHERE id = ? AND symbol = ?", user_id, symbol
    )
    if not owned:
        return apology("You don't own that stock")
    difference = owned[0]["amount"] - shares
    if difference < 0:
        return apology("You don't own that many shares")

    stock = lookup(symbol)
    if not stock:
        return apology("Not existing symbol")

    # Credit the sale proceeds to the user's cash
    total_price = stock["price"] * shares
    db.execute("UPDATE users SET cash = cash + ? WHERE id = ?", total_price, user_id)

    # Log the sale; a negative share count marks it as a sale.
    # NOTE: the history "price" column receives the total proceeds, not the
    # per-share price.
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    db.execute(
        "INSERT INTO history (id, symbol, shares, price, transacted) VALUES (?, ?, ?, ?, ?)",
        user_id, symbol, -shares, total_price, now,
    )

    # Remove the position when fully sold, otherwise store the remainder
    if difference == 0:
        db.execute("DELETE FROM portfolio WHERE id = ? AND symbol = ?", user_id, symbol)
    else:
        db.execute(
            "UPDATE portfolio SET amount = ? WHERE id = ? AND symbol = ?",
            difference, user_id, symbol,
        )

    return redirect("/")
def errorhandler(e):
    """Render an apology page for any error raised while handling a request.

    Args:
        e: the exception that was raised.

    Returns:
        An apology showing the HTTP error's name and status code. Exceptions
        that are not ``HTTPException`` are reported as ``InternalServerError``.
    """
    http_error = e if isinstance(e, HTTPException) else InternalServerError()
    return apology(http_error.name, http_error.code)
# Route every standard HTTP error status through errorhandler
for code in default_exceptions:
    app.register_error_handler(code, errorhandler)