331 lines
10 KiB
Python
331 lines
10 KiB
Python
|
import os
|
||
|
|
||
|
from cs50 import SQL
|
||
|
from flask import Flask, flash, redirect, render_template, request, session
|
||
|
from flask_session import Session
|
||
|
from tempfile import mkdtemp
|
||
|
from werkzeug.exceptions import default_exceptions, HTTPException, InternalServerError
|
||
|
from werkzeug.security import check_password_hash, generate_password_hash
|
||
|
from datetime import datetime
|
||
|
|
||
|
from helpers import apology, login_required, lookup, usd
|
||
|
|
||
|
# Create the Flask application object for this module
app = Flask(__name__)

# Have Flask reload Jinja templates when their source files change
app.config.update(TEMPLATES_AUTO_RELOAD=True)
|
||
|
|
||
|
|
||
|
# Ensure responses aren't cached
@app.after_request
def after_request(response):
    """Stamp every outgoing response with headers that forbid caching, then return it."""
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Expires": 0,
        "Pragma": "no-cache",
    }
    # Applied in the same order as the header names are listed above
    for header_name, header_value in no_cache_headers.items():
        response.headers[header_name] = header_value
    return response
|
||
|
|
||
|
|
||
|
# Expose the usd() money formatter to Jinja templates as the "usd" filter
app.jinja_env.filters["usd"] = usd

# Keep session data on the filesystem, in a fresh temporary directory, instead of in signed
# cookies; sessions are not marked permanent
app.config.update(
    SESSION_FILE_DIR=mkdtemp(),
    SESSION_PERMANENT=False,
    SESSION_TYPE="filesystem",
)
Session(app)

# Configure CS50 Library to use the SQLite database stored in finance.db
db = SQL("sqlite:///finance.db")

# Refuse to start without an API_KEY in the environment
# NOTE(review): presumably consumed by helpers.lookup() — confirm
if not os.getenv("API_KEY"):
    raise RuntimeError("API_KEY not set")
|
||
|
|
||
|
|
||
|
@app.route("/")
@login_required
def index():
    """Show portfolio of stocks.

    Each portfolio row of the logged-in user is enriched in place with:
    - "name": the company name;
    - "price": the current share price, formatted with usd();
    - "total": the current value of the holding, formatted with usd().

    The page also receives the user's cash and the grand total (cash plus the
    value of every holding), both formatted with usd().

    If a current quote cannot be fetched for a held symbol, an apology page
    naming that symbol is returned instead of failing with an internal error.
    """
    user_id = session["user_id"]
    # portfolio.id holds the owning user's id
    portfolio = db.execute("SELECT * FROM portfolio WHERE id = ?", user_id)
    cash = db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"]
    grand_total = cash

    # Add live name / price / total information to every holding
    for row in portfolio:
        stock = lookup(row["symbol"])
        if not stock:
            # lookup() signals an unknown symbol / failed quote with a falsy value;
            # subscripting it would otherwise raise and surface as a generic 500
            return apology(f"Unable to look up {row['symbol']}")

        price = float(stock["price"])
        total = price * int(row["amount"])

        row["name"] = stock["name"]
        row["price"] = usd(price)
        row["total"] = usd(total)
        grand_total += total

    # Money values are formatted here; the template receives ready-made strings
    return render_template("index.html", portfolio=portfolio, cash=usd(cash), grand_total=usd(grand_total))
|
||
|
|
||
|
|
||
|
@app.route("/buy", methods=["GET", "POST"])
@login_required
def buy():
    """Buy shares of stock.

    GET renders the purchase form (buy.html).

    POST flow:
    1. Validate the "symbol" and "shares" form fields.
    2. Look up the current price.
    3. Check that the user's cash covers price * shares.
    4. Deduct the cost and record the purchase in the history table.
    5. Add the shares to the user's portfolio row for that symbol, creating
       the row if needed.
    6. Redirect to the home page.

    Validation failures are reported with apology().
    """
    # User reached route via GET (or anything other than a form POST): show the form
    if request.method != "POST":
        return render_template("buy.html")

    user_id = session["user_id"]
    # .get() returns None for an absent field, so normalize to "" before testing
    symbol = (request.form.get("symbol") or "").strip()
    shares = (request.form.get("shares") or "").strip()

    if not symbol:
        return apology("Missing symbol")
    if not shares:
        return apology("Missing number of shares")
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse;
    # isdigit() lets through e.g. "²", which would make int() raise ValueError
    if not shares.isdecimal() or int(shares) < 1:
        return apology("Number of shares must be a positive integer.")

    # Getting stock info from the quote API
    stock = lookup(symbol)
    if not stock:
        return apology("Not existing symbol")

    shares = int(shares)
    # Portfolio/history rows are keyed by the upper-cased symbol
    symbol = symbol.upper()

    # Compare the total cost against the user's available cash
    total_price = stock["price"] * shares
    funds = float(db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"])
    if funds < total_price:
        return apology("You don't have enough money")

    # Deduct the cost and log the purchase.
    # NOTE(review): the history "price" column receives the TOTAL cost of the order,
    # not the per-share price — confirm that is what history.html expects.
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    db.execute("UPDATE users SET cash = cash - ? WHERE id = ?", total_price, user_id)
    db.execute("INSERT INTO history (id, symbol, shares, price, transacted) VALUES (?, ?, ?, ?, ?)",
               user_id, symbol, shares, total_price, now)

    # Increase an existing holding, or create a new portfolio row
    owned = db.execute("SELECT * FROM portfolio WHERE id = ? AND symbol = ?", user_id, symbol)
    if owned:
        db.execute("UPDATE portfolio SET amount = amount + ? WHERE id = ? AND symbol = ?", shares, user_id, symbol)
    else:
        db.execute("INSERT INTO portfolio (id, symbol, amount) VALUES (?, ?, ?)", user_id, symbol, shares)

    # Redirect user to home page
    return redirect("/")
|
||
|
|
||
|
|
||
|
@app.route("/history")
@login_required
def history():
    """Show history of transactions: every history row recorded for the logged-in user."""
    transactions = db.execute("SELECT * FROM history WHERE id = ?", session["user_id"])
    return render_template("history.html", history=transactions)
|
||
|
|
||
|
|
||
|
@app.route("/login", methods=["GET", "POST"])
def login():
    """Log user in.

    Any existing session data is discarded first. GET renders the login form.

    POST requires non-empty "username" and "password" fields. Exactly one user
    must match the username, and check_password_hash must accept the password
    against that user's stored hash. On success the user's id is stored in the
    session and the browser is redirected to "/". Any failure returns an
    apology with status 403.
    """
    # Forget whoever was logged in before
    session.clear()

    # Anything other than a form submission just gets the form
    if request.method != "POST":
        return render_template("login.html")

    username = request.form.get("username")
    password = request.form.get("password")

    if not username:
        return apology("must provide username", 403)
    if not password:
        return apology("must provide password", 403)

    # Require exactly one matching account whose stored hash accepts the password
    matches = db.execute("SELECT * FROM users WHERE username = ?", username)
    if len(matches) != 1 or not check_password_hash(matches[0]["hash"], password):
        return apology("invalid username and/or password", 403)

    # Remember which user has logged in, then go to the home page
    session["user_id"] = matches[0]["id"]
    return redirect("/")
|
||
|
|
||
|
|
||
|
@app.route("/logout")
def logout():
    """Log user out by discarding all session data, then redirect to "/"."""
    # Wiping the session removes user_id, which is what marks a user as logged in
    session.clear()
    # NOTE(review): "/" is login-protected, so this presumably lands on the login form
    return redirect("/")
|
||
|
|
||
|
|
||
|
@app.route("/quote", methods=["GET", "POST"])
@login_required
def quote():
    """Get stock quote.

    GET renders the quote form (quote.html). POST looks up the submitted
    "symbol" and renders quoted.html with the quote's symbol, company name and
    usd()-formatted price. A missing or unknown symbol yields an apology.
    """
    # User reached route via GET: show the form
    if request.method != "POST":
        return render_template("quote.html")

    # .get() returns None for an absent field; normalize so the emptiness check catches it
    symbol = (request.form.get("symbol") or "").strip()
    if not symbol:
        return apology("Missing symbol")

    # Getting stock info from the quote API
    stock = lookup(symbol)
    if not stock:
        return apology("Not existing symbol")

    return render_template("quoted.html", symbol=stock["symbol"], name=stock["name"], price=usd(stock["price"]))
|
||
|
|
||
|
|
||
|
@app.route("/register", methods=["GET", "POST"])
def register():
    """Register user.

    GET renders the registration form (register.html).

    POST requirements:
    - "username", "password" and "confirmation" are all required.
    - The two passwords must match.
    - The password must contain at least one uppercase letter, one digit and
      one "symbol" (any character that is not uppercase, lowercase or a digit).
    - The username must not already exist.

    On success the password hash is stored, the new user is logged in
    (session["user_id"]) and the browser is redirected to "/". Every
    validation failure returns an apology.
    """
    if request.method != "POST":
        return render_template("register.html")

    # .get() returns None for absent fields; treat that the same as an empty field
    name = request.form.get("username") or ""
    password = request.form.get("password") or ""
    confirmation = request.form.get("confirmation") or ""

    if not name:
        return apology("Missing username")
    if not password:
        return apology("Missing password")
    if not confirmation:
        return apology("Missing password confirmation")
    if password != confirmation:
        return apology("Passwords not matching")

    # Personal touch - password strength check
    has_upper = any(char.isupper() for char in password)
    has_digit = any(char.isdigit() for char in password)
    # A "symbol" is anything that is neither uppercase, a digit, nor lowercase
    has_symbol = any(not (char.isupper() or char.isdigit() or char.islower()) for char in password)
    if not (has_upper and has_digit and has_symbol):
        return apology("Password must contain 1 uppercase, 1 number and 1 symbol")

    # Reject usernames that are already taken
    if db.execute("SELECT * FROM users WHERE username = ?", name):
        return apology("Username is already used")

    # Store only the password hash, never the plain password
    db.execute("INSERT INTO users (username, hash) VALUES (?, ?)", name, generate_password_hash(password))

    # Remember which user is logged in and redirect to the homepage
    rows = db.execute("SELECT * FROM users WHERE username = ?", name)
    session["user_id"] = rows[0]["id"]
    return redirect("/")
|
||
|
|
||
|
|
||
|
@app.route("/sell", methods=["GET", "POST"])
@login_required
def sell():
    """Sell shares of stock.

    GET renders sell.html with the logged-in user's portfolio rows.

    POST flow:
    1. Validate the "symbol" and "shares" form fields.
    2. Check that the user owns at least that many shares.
    3. Credit current price * shares to the user's cash.
    4. Record the sale in history with a negative share count.
    5. Update the portfolio row, deleting it when no shares remain.
    6. Redirect to "/".

    Invalid input, unowned symbols and failed quotes yield an apology instead
    of an internal error.
    """
    user_id = session["user_id"]

    # User reached route via GET: show the form with the user's holdings
    if request.method != "POST":
        portfolio = db.execute("SELECT * FROM portfolio WHERE id = ?", user_id)
        return render_template("sell.html", portfolio=portfolio)

    # .get() returns None for absent fields; normalize before validating.
    # Symbols are stored upper-cased when bought, so match that here.
    symbol = (request.form.get("symbol") or "").strip().upper()
    shares = (request.form.get("shares") or "").strip()

    if not symbol:
        return apology("Missing symbol")
    # Validate before converting: int() on "", "abc" or "²" would raise and crash the route
    if not shares.isdecimal() or int(shares) < 1:
        return apology("Invalid number of shares")
    shares = int(shares)

    # The user must hold this symbol at all before we can compare amounts
    owned = db.execute("SELECT amount FROM portfolio WHERE id = ? AND symbol = ?", user_id, symbol)
    if not owned:
        return apology("You don't own any shares of that stock")

    # Difference between the owned shares and the shares to sell
    difference = owned[0]["amount"] - shares
    if difference < 0:
        return apology("You don't own that many shares")

    stock = lookup(symbol)
    if not stock:
        return apology("Not existing symbol")

    # Credit the user with the current value of the sold shares
    total_price = stock["price"] * shares
    db.execute("UPDATE users SET cash = cash + ? WHERE id = ?", total_price, user_id)

    # Record the sale; a negative share count marks it as a sell.
    # NOTE(review): the history "price" column receives the TOTAL proceeds, not the per-share price.
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    db.execute("INSERT INTO history (id, symbol, shares, price, transacted) VALUES (?, ?, ?, ?, ?)",
               user_id, symbol, -shares, total_price, now)

    # If the whole holding was sold, delete the portfolio row; otherwise store the remainder
    if difference == 0:
        db.execute("DELETE FROM portfolio WHERE id = ? AND symbol = ?", user_id, symbol)
    else:
        db.execute("UPDATE portfolio SET amount = ? WHERE id = ? AND symbol = ?", difference, user_id, symbol)

    # Redirect user to home page
    return redirect("/")
|
||
|
|
||
|
|
||
|
def errorhandler(e):
    """Render an apology page for an error.

    HTTP exceptions are shown with their own name and status code. Anything
    else is reported as a generic InternalServerError (500).
    """
    http_error = e if isinstance(e, HTTPException) else InternalServerError()
    return apology(http_error.name, http_error.code)
|
||
|
|
||
|
|
||
|
# Listen for errors: route every status code Werkzeug knows about through errorhandler()
for code in default_exceptions:
    app.register_error_handler(code, errorhandler)
|