mirror of
https://github.com/shouptech/flask-tutorial.git
synced 2026-02-03 07:29:42 +00:00
use sqlalchemy
This commit is contained in:
parent
c4c170d53d
commit
c3a3eb30e2
6 changed files with 61 additions and 88 deletions
|
|
@ -7,8 +7,11 @@ def create_app(test_config=None):
|
||||||
# create and configure the app
|
# create and configure the app
|
||||||
app = Flask(__name__, instance_relative_config=True)
|
app = Flask(__name__, instance_relative_config=True)
|
||||||
app.config.from_mapping(
|
app.config.from_mapping(
|
||||||
DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'),
|
SECRET_KEY = os.environ.get("SECRET_KEY", default=None),
|
||||||
SECRET_KEY = os.environ.get("SECRET_KEY", default=None)
|
SQLALCHEMY_TRACK_MODIFICATIONS = os.environ.get(
|
||||||
|
"SQLALCHEMY_TRACK_MODIFICATIONS", default=False),
|
||||||
|
SQLALCHEMY_DATABASE_URI = os.environ.get(
|
||||||
|
"SQLALCHEMY_DATABASE_URI", default=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
if test_config is None:
|
if test_config is None:
|
||||||
|
|
@ -23,7 +26,6 @@ def create_app(test_config=None):
|
||||||
os.makedirs(app.instance_path)
|
os.makedirs(app.instance_path)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# a simple page that says hello
|
# a simple page that says hello
|
||||||
@app.route('/hello')
|
@app.route('/hello')
|
||||||
def hello():
|
def hello():
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from flask import (
|
||||||
)
|
)
|
||||||
from werkzeug.security import check_password_hash, generate_password_hash
|
from werkzeug.security import check_password_hash, generate_password_hash
|
||||||
|
|
||||||
from flaskr.db import get_db
|
from flaskr.db import db, User
|
||||||
|
|
||||||
bp = Blueprint('auth', __name__, url_prefix='/auth')
|
bp = Blueprint('auth', __name__, url_prefix='/auth')
|
||||||
|
|
||||||
|
|
@ -14,24 +14,18 @@ def register():
|
||||||
if request.method == 'POST':
|
if request.method == 'POST':
|
||||||
username = request.form['username']
|
username = request.form['username']
|
||||||
password = request.form['password']
|
password = request.form['password']
|
||||||
db = get_db()
|
|
||||||
error = None
|
error = None
|
||||||
|
|
||||||
if not username:
|
if not username:
|
||||||
error = 'Username is required.'
|
error = 'Username is required.'
|
||||||
elif not password:
|
elif not password:
|
||||||
error = 'Password is required.'
|
error = 'Password is required.'
|
||||||
elif db.execute(
|
elif User.query.filter_by(username=username).first() is not None:
|
||||||
'SELECT id FROM user WHERE username = ?', (username,)
|
|
||||||
).fetchone() is not None:
|
|
||||||
error = 'User {} is already registered.'.format(username)
|
error = 'User {} is already registered.'.format(username)
|
||||||
|
else:
|
||||||
if error is None:
|
db.session.add(User(username=username,
|
||||||
db.execute(
|
password=generate_password_hash(password)))
|
||||||
'INSERT INTO user (username, password) VALUES (?, ?)',
|
db.session.commit()
|
||||||
(username, generate_password_hash(password))
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
return redirect(url_for('auth.login'))
|
return redirect(url_for('auth.login'))
|
||||||
|
|
||||||
flash(error)
|
flash(error)
|
||||||
|
|
@ -43,20 +37,18 @@ def login():
|
||||||
if request.method == 'POST':
|
if request.method == 'POST':
|
||||||
username = request.form['username']
|
username = request.form['username']
|
||||||
password = request.form['password']
|
password = request.form['password']
|
||||||
db = get_db()
|
|
||||||
error = None
|
error = None
|
||||||
user = db.execute(
|
|
||||||
'SELECT * FROM user WHERE username = ?', (username,)
|
user = User.query.filter_by(username=username).first()
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
error = 'Incorrect username.'
|
error = 'Incorrect username.'
|
||||||
elif not check_password_hash(user['password'], password):
|
elif not check_password_hash(user.password, password):
|
||||||
error = 'Incorrect password.'
|
error = 'Incorrect password.'
|
||||||
|
|
||||||
if error is None:
|
if error is None:
|
||||||
session.clear()
|
session.clear()
|
||||||
session['user_id'] = user['id']
|
session['user_id'] = user.id
|
||||||
return redirect(url_for('index'))
|
return redirect(url_for('index'))
|
||||||
|
|
||||||
flash(error)
|
flash(error)
|
||||||
|
|
@ -70,9 +62,7 @@ def load_logged_in_user():
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
g.user = None
|
g.user = None
|
||||||
else:
|
else:
|
||||||
g.user = get_db().execute(
|
g.user = User.query.filter_by(id=user_id).first()
|
||||||
'SELECT * FROM user WHERE id = ?', (user_id,)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
@bp.route('/logout')
|
@bp.route('/logout')
|
||||||
def logout():
|
def logout():
|
||||||
|
|
|
||||||
|
|
@ -4,18 +4,13 @@ from flask import (
|
||||||
from werkzeug.exceptions import abort
|
from werkzeug.exceptions import abort
|
||||||
|
|
||||||
from flaskr.auth import login_required
|
from flaskr.auth import login_required
|
||||||
from flaskr.db import get_db
|
from flaskr.db import db, Post, User
|
||||||
|
|
||||||
bp = Blueprint('blog', __name__)
|
bp = Blueprint('blog', __name__)
|
||||||
|
|
||||||
@bp.route('/')
|
@bp.route('/')
|
||||||
def index():
|
def index():
|
||||||
db = get_db()
|
posts = Post.query.join(User).order_by(Post.created.desc()).all()
|
||||||
posts = db.execute(
|
|
||||||
'SELECT p.id, title, body, created, author_id, username'
|
|
||||||
' FROM post p JOIN user u ON p.author_id = u.id'
|
|
||||||
' ORDER BY created DESC'
|
|
||||||
).fetchall()
|
|
||||||
return render_template('blog/index.html', posts=posts)
|
return render_template('blog/index.html', posts=posts)
|
||||||
|
|
||||||
@bp.route('/create', methods=('GET', 'POST'))
|
@bp.route('/create', methods=('GET', 'POST'))
|
||||||
|
|
@ -32,29 +27,21 @@ def create():
|
||||||
if error is not None:
|
if error is not None:
|
||||||
flash(error)
|
flash(error)
|
||||||
else:
|
else:
|
||||||
db = get_db()
|
db.session.add(Post(title=title,
|
||||||
db.execute(
|
body=body,
|
||||||
'INSERT INTO post (title, body, author_id)'
|
user_id=g.user.id))
|
||||||
' VALUES (?, ?, ?)',
|
db.session.commit()
|
||||||
(title, body, g.user['id'])
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
return redirect(url_for('blog.index'))
|
return redirect(url_for('blog.index'))
|
||||||
|
|
||||||
return render_template('blog/create.html')
|
return render_template('blog/create.html')
|
||||||
|
|
||||||
def get_post(id, check_author=True):
|
def get_post(id, check_author=True):
|
||||||
post = get_db().execute(
|
post = Post.query.filter_by(id=id).first()
|
||||||
'SELECT p.id, title, body, created, author_id, username'
|
|
||||||
' FROM post p JOIN user u ON p.author_id = u.id'
|
|
||||||
' WHERE p.id = ?',
|
|
||||||
(id,)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if post is None:
|
if post is None:
|
||||||
abort(404, "Post id {0} doesn't exist.".format(id))
|
abort(404, "Post id {0} doesn't exist.".format(id))
|
||||||
|
|
||||||
if check_author and post['author_id'] != g.user['id']:
|
if check_author and post.user_id != g.user.id:
|
||||||
abort(403)
|
abort(403)
|
||||||
|
|
||||||
return post
|
return post
|
||||||
|
|
@ -75,13 +62,10 @@ def update(id):
|
||||||
if error is not None:
|
if error is not None:
|
||||||
flash(error)
|
flash(error)
|
||||||
else:
|
else:
|
||||||
db = get_db()
|
post = Post.query.filter_by(id=id).first()
|
||||||
db.execute(
|
post.title = title
|
||||||
'UPDATE post SET title = ?, body = ?'
|
post.body = body
|
||||||
' WHERE id = ?',
|
db.session.commit()
|
||||||
(title, body, id)
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
return redirect(url_for('blog.index'))
|
return redirect(url_for('blog.index'))
|
||||||
|
|
||||||
return render_template('blog/update.html', post=post)
|
return render_template('blog/update.html', post=post)
|
||||||
|
|
@ -89,8 +73,6 @@ def update(id):
|
||||||
@bp.route('/<int:id>/delete', methods=('POST',))
|
@bp.route('/<int:id>/delete', methods=('POST',))
|
||||||
@login_required
|
@login_required
|
||||||
def delete(id):
|
def delete(id):
|
||||||
get_post(id)
|
db.session.delete(get_post(id))
|
||||||
db = get_db()
|
db.session.commit()
|
||||||
db.execute('DELETE FROM post WHERE id = ?', (id,))
|
|
||||||
db.commit()
|
|
||||||
return redirect(url_for('blog.index'))
|
return redirect(url_for('blog.index'))
|
||||||
|
|
|
||||||
52
flaskr/db.py
52
flaskr/db.py
|
|
@ -1,41 +1,39 @@
|
||||||
import sqlite3
|
from datetime import datetime
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
from flask.cli import with_appcontext
|
from flask.cli import with_appcontext
|
||||||
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
|
|
||||||
|
db = SQLAlchemy()
|
||||||
def get_db():
|
|
||||||
if 'db' not in g:
|
|
||||||
g.db = sqlite3.connect(
|
|
||||||
current_app.config['DATABASE'],
|
|
||||||
detect_types=sqlite3.PARSE_DECLTYPES
|
|
||||||
)
|
|
||||||
g.db.row_factory = sqlite3.Row
|
|
||||||
|
|
||||||
return g.db
|
|
||||||
|
|
||||||
|
|
||||||
def close_db(e=None):
|
|
||||||
db = g.pop('db', None)
|
|
||||||
|
|
||||||
if db is not None:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
def init_db():
|
|
||||||
db = get_db()
|
|
||||||
|
|
||||||
with current_app.open_resource('schema.sql') as f:
|
|
||||||
db.executescript(f.read().decode('utf8'))
|
|
||||||
|
|
||||||
|
|
||||||
@click.command('init-db')
|
@click.command('init-db')
|
||||||
@with_appcontext
|
@with_appcontext
|
||||||
def init_db_command():
|
def init_db_command():
|
||||||
"""Clear the existing data and create new tables."""
|
"""Clear the existing data and create new tables."""
|
||||||
init_db()
|
db.create_all()
|
||||||
click.echo('Initialized the database.')
|
click.echo('Initialized the database.')
|
||||||
|
|
||||||
def init_app(app):
|
def init_app(app):
|
||||||
app.teardown_appcontext(close_db)
|
db.init_app(app) # Initialize the sql alchemy database
|
||||||
app.cli.add_command(init_db_command)
|
app.cli.add_command(init_db_command)
|
||||||
|
|
||||||
|
# Model definitions
|
||||||
|
class User(db.Model):
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
username = db.Column(db.Text, nullable=False)
|
||||||
|
password = db.Column(db.Text, nullable=False)
|
||||||
|
posts = db.relationship('Post', backref='user', lazy=True)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<User %r>' % self.username
|
||||||
|
|
||||||
|
class Post(db.Model):
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
|
||||||
|
created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
|
||||||
|
title = db.Column(db.Text, nullable=False)
|
||||||
|
body = db.Column(db.Text, nullable=False)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<Post %r>' % self.title
|
||||||
|
|
|
||||||
|
|
@ -12,14 +12,14 @@
|
||||||
<article class="post">
|
<article class="post">
|
||||||
<header>
|
<header>
|
||||||
<div>
|
<div>
|
||||||
<h1>{{ post['title'] }}</h1>
|
<h1>{{ post.title }}</h1>
|
||||||
<div class="about">by {{ post['username'] }} on {{ post['created'].strftime('%Y-%m-%d') }}</div>
|
<div class="about">by {{ post.user.username }} on {{ post.created.strftime('%Y-%m-%d') }}</div>
|
||||||
</div>
|
</div>
|
||||||
{% if g.user['id'] == post['author_id'] %}
|
{% if g.user.id == post.user_id %}
|
||||||
<a class="action" href="{{ url_for('blog.update', id=post['id']) }}">Edit</a>
|
<a class="action" href="{{ url_for('blog.update', id=post.id) }}">Edit</a>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</header>
|
</header>
|
||||||
<p class="body">{{ post['body'] }}</p>
|
<p class="body">{{ post.body }}</p>
|
||||||
</article>
|
</article>
|
||||||
{% if not loop.last %}
|
{% if not loop.last %}
|
||||||
<hr>
|
<hr>
|
||||||
|
|
|
||||||
1
setup.py
1
setup.py
|
|
@ -8,5 +8,6 @@ setup(
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'flask',
|
'flask',
|
||||||
|
'flask_sqlalchemy',
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue