
JWT Authentication in GraphQL with Python

Implement JWT authentication in a GraphQL API with Python, covering user registration, login, and securing mutations.

Author: James Lau (Indie App Developer, self-employed)

This blog post guides you through implementing JSON Web Token (JWT) authentication in a GraphQL API using Python. We’ll cover user registration, login, and securing GraphQL mutations with JWTs.

Prerequisites

Ensure you have Python installed. You’ll also need the following packages:

pip install PyJWT bcrypt python-dotenv pydantic graphene sqlalchemy

You also need a .env file in the project root that defines SECRET_KEY, for example (use your own long random value rather than this published one):

SECRET_KEY=2%*Zsr9Myadsfcxt7xVA!iSGhhRX%cTCdiE3i#wLAQBK@rB$EQ3CoVS74gJMPw!*T$AiXXVDpGEsadgasdgq@xiqPTY$gy5Ji4sdsfjSrK3cvwc!$#MwPdwijqwn
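Rather than copying the key above, you can generate a fresh one with Python's standard `secrets` module. This is a minimal sketch: 64 random bytes is comfortably above the 256-bit minimum that RFC 7518 sets for HS256 keys, and the URL-safe alphabet avoids characters such as `$` and `#` that some `.env` parsers and shells treat specially.

```python
import secrets

# 64 random bytes from the OS CSPRNG, base64url-encoded without padding
# (86 characters drawn from A-Z, a-z, 0-9, "-" and "_")
secret_key = secrets.token_urlsafe(64)

# Paste the printed line into your .env file
print(f"SECRET_KEY={secret_key}")
```

Run it once, store the result in .env, and keep that file out of version control.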

1. Project Setup and Dependencies

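Start by installing the packages listed above. PyJWT creates and verifies the tokens, bcrypt hashes passwords securely, python-dotenv loads SECRET_KEY from .env, and graphene, pydantic, and SQLAlchemy provide the GraphQL schema, input validation, and database layer.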

2. Defining User Schema (schemas.py)

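Create a schemas.py file that defines the User schema (and the Post schema used in step 5) with Pydantic’s BaseModel:

from pydantic import BaseModel

class UserSchema(BaseModel):
    username: str
    password: str

class PostSchema(BaseModel):
    # Assumed fields, inferred from how CreateNewPost uses it in step 5
    title: str
    content: str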

3. JWT Token Management (jwt_token.py)

Create a jwt_token.py file to handle JWT creation and decoding:

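import os
from datetime import datetime, timedelta, timezone

import jwt
from dotenv import load_dotenv

load_dotenv(".env")

# Raises KeyError at import time if SECRET_KEY is not set
secret_key = os.environ["SECRET_KEY"]

algorithm = "HS256"

def create_access_token(data: dict, expires_delta: timedelta) -> str:
    to_encode = data.copy()
    # Timezone-aware UTC time; datetime.utcnow() is deprecated since Python 3.12
    expire = datetime.now(timezone.utc) + expires_delta
    # PyJWT converts a datetime in "exp" to a Unix timestamp
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, secret_key, algorithm=algorithm)

def decode_access_token(data: str) -> dict:
    # algorithms takes a list of accepted algorithms; decode also checks the
    # signature and raises ExpiredSignatureError once "exp" has passed
    return jwt.decode(data, secret_key, algorithms=[algorithm])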

This code defines two functions:

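  • create_access_token: Creates a signed JWT whose exp claim is the current UTC time plus expires_delta (the login mutation below passes 60 minutes).
  • decode_access_token: Verifies the token's signature and expiry and returns its payload, which here holds the username under "user". It raises a PyJWTError subclass if the token is invalid or expired.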
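To make concrete what jwt.encode and jwt.decode do for HS256, here is a standard-library-only sketch. A token is base64url(header).base64url(payload).base64url(signature), where the signature is an HMAC-SHA256 over the first two parts; verification recomputes that signature, pins the algorithm (the job of algorithms=[...] in PyJWT), and checks exp. The helper names are hypothetical and the sketch is for illustration only; keep using PyJWT in the real code, since it handles other registered claims and edge cases skipped here.

```python
import base64
import hashlib
import hmac
import json
import time


def b64url_encode(raw: bytes) -> str:
    # JWTs use base64url without "=" padding
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def b64url_decode(text: str) -> bytes:
    # Restore the padding that b64url_encode stripped
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign_hs256(claims: dict, key: bytes) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    # Compact JSON for header and payload, each base64url-encoded
    signing_input = ".".join(
        b64url_encode(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, claims)
    )
    signature = hmac.new(key, signing_input.encode("ascii"), hashlib.sha256).digest()
    return f"{signing_input}.{b64url_encode(signature)}"


def verify_hs256(token: str, key: bytes) -> dict:
    header_b64, payload_b64, signature_b64 = token.split(".")
    header = json.loads(b64url_decode(header_b64))
    # Accept only the algorithm we expect, never whatever the token claims
    if header.get("alg") != "HS256":
        raise ValueError("unexpected algorithm")
    expected = hmac.new(
        key, f"{header_b64}.{payload_b64}".encode("ascii"), hashlib.sha256
    ).digest()
    # Constant-time comparison of the recomputed and supplied signatures
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    claims = json.loads(b64url_decode(payload_b64))
    # "exp" is a Unix timestamp in seconds
    if "exp" in claims and time.time() >= claims["exp"]:
        raise ValueError("token expired")
    return claims


key = b"example-secret"
token = sign_hs256({"user": "alice", "exp": int(time.time()) + 3600}, key)
print(verify_hs256(token, key))  # prints the claims dict
```

Anyone holding the secret can mint valid tokens, which is why SECRET_KEY must stay private.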

4. Implementing User Registration and Authentication (main.py)

In your main.py (or wherever your GraphQL mutations are defined), implement the CreateNewUser and AuthenticateUser mutations:

from datetime import timedelta

import bcrypt
import graphene
from graphql import GraphQLError
from jwt import PyJWTError

import models
from database import db  # database.py is not shown; assumed to expose a SQLAlchemy Session named db
from jwt_token import create_access_token, decode_access_token
from schemas import PostSchema, UserSchema


class CreateNewUser(graphene.Mutation):
    class Arguments:
        username = graphene.String(required=True)
        password = graphene.String(required=True)

    ok = graphene.Boolean()

    @staticmethod
    def mutate(root, info, username, password):

        # bcrypt.gensalt() creates a random salt; hashpw embeds it in the returned hash (bytes)
        hashed_password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt())

        # Store the hash as a str; see https://stackoverflow.com/questions/34548846/flask-bcrypt-valueerror-invalid-salt
        password_hash = hashed_password.decode("utf-8")

        user = UserSchema(username=username, password=password_hash)
        db_user = models.User(username=user.username, password=user.password)
        db.add(db_user)

        # Roll back on failure (e.g. a duplicate username) so the session stays usable:
        # https://docs.sqlalchemy.org/en/13/faq/sessions.html#this-session-s-transaction-has-been-rolled-back-due-to-a-previous-exception-during-flush-or-similar
        try:
            db.commit()
            db.refresh(db_user)
            return CreateNewUser(ok=True)

        except Exception:
            db.rollback()
            raise

        finally:
            db.close()

class AuthenticateUser(graphene.Mutation):
    class Arguments:
        username = graphene.String(required=True)
        password = graphene.String(required=True)

    ok = graphene.Boolean()
    token = graphene.String()

    @staticmethod
    def mutate(root, info, username, password):

        user = UserSchema(username=username, password=password)
        db_user_info = db.query(models.User).filter(models.User.username == user.username).first()

        # Same response for an unknown user and a wrong password, so callers
        # cannot tell which usernames exist (the original fell through and returned None here)
        if db_user_info is None or not bcrypt.checkpw(
            user.password.encode("utf-8"), db_user_info.password.encode("utf-8")
        ):
            return AuthenticateUser(ok=False)

        access_token_expires = timedelta(minutes=60)
        access_token = create_access_token(data={"user": user.username}, expires_delta=access_token_expires)
        return AuthenticateUser(ok=True, token=access_token)

class PostMutations(graphene.ObjectType):
    authenticate_user = AuthenticateUser.Field()
    create_new_user = CreateNewUser.Field()

Key points:

  • CreateNewUser hashes the password using bcrypt before saving it to the database.
  • AuthenticateUser looks up the user and verifies the password with bcrypt.checkpw. On success it issues a JWT via create_access_token; if the user does not exist or the password is wrong, it returns ok=False.
  • The JWT’s expiration time is set to 60 minutes.
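The hash-then-verify flow described in these key points can be tried without a database. The sketch below swaps in the standard library's hashlib.pbkdf2_hmac for bcrypt only so it runs with no extra packages; hash_password, check_password, and the salt$digest storage format are hypothetical. Like bcrypt.hashpw, it uses a fresh random salt per password and stores the salt alongside the hash, so verifying a login means re-deriving the hash with the stored salt and comparing.

```python
import hashlib
import hmac
import os

ITERATIONS = 600_000  # illustrative work factor; bcrypt uses its own cost parameter instead


def hash_password(password: str) -> str:
    salt = os.urandom(16)  # fresh random salt for every password
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, ITERATIONS)
    # Keep the salt with the digest, much as bcrypt embeds its salt in the hash string
    return f"{salt.hex()}${digest.hex()}"


def check_password(password: str, stored: str) -> bool:
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), bytes.fromhex(salt_hex), ITERATIONS
    )
    # Constant-time comparison avoids leaking how many bytes matched
    return hmac.compare_digest(digest, bytes.fromhex(digest_hex))


stored = hash_password("s3cret")
print(check_password("s3cret", stored), check_password("wrong", stored))
```

Because the salt is random, hashing the same password twice yields different stored strings, yet both verify.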

5. Securing Mutations with JWT (main.py)

Modify the CreateNewPost mutation to require a JWT for creating a post:

class CreateNewPost(graphene.Mutation):
    class Arguments:
        title = graphene.String(required=True)
        content = graphene.String(required=True)
        token = graphene.String(required=True)

    result = graphene.String()

    @staticmethod
    def mutate(root, info, title, content, token):
        # decode_access_token raises a PyJWTError subclass for a bad signature or an expired token
        try:
            payload = decode_access_token(data=token)
        except PyJWTError:
            raise GraphQLError("Invalid credentials")

        username = payload.get("user")
        if username is None:
            raise GraphQLError("Invalid credentials")

        # A valid token is not enough: the user may have been deleted since it was issued
        user = db.query(models.User).filter(models.User.username == username).first()
        if user is None:
            raise GraphQLError("Invalid credentials")

        post = PostSchema(title=title, content=content)
        db_post = models.Post(title=post.title, content=post.content)
        try:
            db.add(db_post)
            db.commit()
            db.refresh(db_post)
            return CreateNewPost(result="Added new post")
        except Exception as e:
            # Roll back and report the failure instead of returning an undefined result
            db.rollback()
            raise GraphQLError("Could not add post") from e
        finally:
            db.close()

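Now the CreateNewPost mutation expects a token argument. It decodes the token with decode_access_token, which also checks the signature and expiry, to retrieve the username, and then verifies that the user still exists in the database before creating the post. To expose it, define CreateNewPost above PostMutations and add create_new_post = CreateNewPost.Field() to that class.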
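A client exercises these mutations by POSTing a JSON body with query and variables to your GraphQL endpoint (the server setup is not shown in this post). Graphene camel-cases field names by default, so authenticate_user is exposed as authenticateUser. This sketch only builds the request bodies; it assumes CreateNewPost is registered as create_new_post on the mutation type, a hypothetical field name since the PostMutations class above lists only the login and registration fields.

```python
import json

# Log in and receive a token
LOGIN = """
mutation Login($username: String!, $password: String!) {
  authenticateUser(username: $username, password: $password) {
    ok
    token
  }
}
"""

# Create a post, passing the token as a mutation argument
CREATE_POST = """
mutation CreatePost($title: String!, $content: String!, $token: String!) {
  createNewPost(title: $title, content: $content, token: $token) {
    result
  }
}
"""


def graphql_body(query: str, variables: dict) -> str:
    # Conventional GraphQL-over-HTTP JSON body
    return json.dumps({"query": query, "variables": variables})


login_body = graphql_body(LOGIN, {"username": "alice", "password": "s3cret"})

# Copy data.authenticateUser.token from the login response into this call
post_body = graphql_body(
    CREATE_POST,
    {"title": "Hello", "content": "First post", "token": "<token from login>"},
)
print(login_body)
```

Passing variables rather than splicing values into the query string keeps quoting and escaping out of the client code.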

6. Database Model (models.py)

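Here are basic SQLAlchemy models for User and for the Post model that CreateNewPost uses:

from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.orm import declarative_base  # SQLAlchemy 1.4+; older versions use sqlalchemy.ext.declarative

Base = declarative_base()

class DictMixIn:
    # Convert a model instance to a plain dict of its column values
    def to_dict(self):
        return {c.key: getattr(self, c.key)
                for c in self.__table__.columns}

class User(Base, DictMixIn):
    __tablename__ = "user"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True)
    password = Column(String(255))  # bcrypt hash, never the plain-text password

class Post(Base, DictMixIn):
    # Minimal model assumed from how CreateNewPost uses it; columns are inferred
    __tablename__ = "post"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String)
    content = Column(Text)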

Conclusion

This blog post demonstrated how to implement JWT authentication in a GraphQL API using Python. You can extend this example by adding features like token refresh, role-based authorization, and more robust error handling. Remember to handle your SECRET_KEY securely in a production environment.
