Skip to content

Commit

Permalink
feat: Add AI image generation in /profile
Browse files Browse the repository at this point in the history
  • Loading branch information
seriaati committed Mar 7, 2024
1 parent 0a2bbcc commit 307c4a6
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/bot/bot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, TypeAlias

Expand All @@ -9,6 +10,7 @@
from cachetools import TTLCache
from discord.ext import commands, tasks

from ..hoyo.novelai_client import NAIClient
from ..utils import get_now
from .command_tree import CommandTree
from .translator import AppCommandTranslator, Translator
Expand Down Expand Up @@ -60,6 +62,9 @@ def __init__(
self.translator = translator
self.env = env
self.diskcache = diskcache.Cache("./.cache/hoyo_buddy")
self.nai_client = NAIClient(
token=os.environ["NAI_TOKEN"], host_url=os.environ["NAI_HOST_URL"]
)

async def setup_hook(self) -> None:
await self.tree.set_translator(AppCommandTranslator(self.translator))
Expand All @@ -74,6 +79,8 @@ async def setup_hook(self) -> None:

await self.load_extension("jishaku")

await self.nai_client.init(timeout=120)

self.push_source_strings.start()

def capture_exception(self, e: Exception) -> None:
Expand Down
150 changes: 150 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,153 @@
EnkaLanguage.ITALIAN: "it",
EnkaLanguage.TURKISH: "tr",
}

# NSFW tags for AI art generation feature
NSFW_TAGS: set[str] = {
"nsfw",
"pussy",
"vaginal",
"futanari",
"nipple pull",
"cowgirl position",
"cum on breast",
"happy sex",
"bound arms",
"cock in thighhigh",
"deepthroat",
"fruit insertion",
"transformation",
"mound of venus",
"footjob",
"wakamezake",
"leg lock",
"upright straddle",
"penis",
"thigh sex",
"stomach bulge",
"testicles",
"female ejaculation",
"breast grab",
"spanked",
"mind control",
"cum on food",
"nyotaimori",
"breast feeding",
"double penetration",
"double vaginal",
"tally",
"femdom",
"cum",
"grinding",
"ring gag",
"groping",
"enema",
"tribadism",
"navel piercing",
"buttjob",
"vibrator",
"vibrator in thighhighs",
"masturbation",
"pubic hair",
"cunnilingus",
"clitoris",
"body writing",
"clothed masturbation",
"frogtie",
"facial",
"cameltoe",
"x-ray",
"cross-section",
"internal cumshot",
"exhibitionism",
"bra lift",
"caught",
"blood",
"tapegag",
"tamakeri",
"nipple torture",
"walk-in",
"gangbang",
"anal insertion",
"anus",
"anilingus",
"anal beads",
"voyeurism",
"crotch rope",
"crotch rub",
"rope",
"humiliation",
"group sex",
"orgy",
"teamwork",
"spread ass",
"no pussy",
"underwater sex",
"insertion",
"nipple tweak",
"rape",
"penetration",
"cum on hair",
"bound wrists",
"ejaculation",
"reverse cowgirl",
"lactation",
"breast sucking",
"nipple suck",
"girl on top",
"wide hips",
"large insertion",
"anal fingering",
"bondage",
"bitgag",
"handjob",
"cleave gag",
"bdsm",
"pillory",
"stocks",
"shibari",
"fingering",
"cock ring",
"huge ass",
"censored",
"gokkun",
"vore",
"puffy nipples",
"leash",
"gag",
"ass",
"amputee",
"hitachi magic wand",
"uncensored",
"panty gag",
"fat mons",
"ballgag",
"clothed sex",
"piercing",
"anal",
"pussy juice",
"doggystyle",
"ganguro",
"hogtie",
"wooden horse",
"breast smother",
"fisting",
"suspension",
"anal fisting",
"have to pee",
"peeing",
"small nipples",
"cervix",
"multiple paizuri",
"oral",
"fellatio",
"hairjob",
"virgin",
"facesitting",
"double anal",
"pegging",
"slave",
"about to be raped",
"sex",
"molestation",
}
22 changes: 22 additions & 0 deletions src/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,25 @@ def __init__(self) -> None:
key="verification_code_service_unavailable_description",
),
)


class NSFWPromptError(HoyoBuddyError):
def __init__(self) -> None:
super().__init__(
title=LocaleStr("NSFW Prompt", key="nsfw_prompt_error_title"),
message=LocaleStr(
"The prompt contains NSFW content, please try again with a different prompt.",
key="nsfw_prompt_error_message",
),
)


class GuildOnlyFeatureError(HoyoBuddyError):
def __init__(self) -> None:
super().__init__(
title=LocaleStr("Guild Only Feature", key="guild_only_feature_error_title"),
message=LocaleStr(
"This feature is only available in guilds, please try again in a guild.",
key="guild_only_feature_error_message",
),
)
21 changes: 21 additions & 0 deletions src/hoyo/novelai_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import novelai


class NAIClient(novelai.NAIClient):
def __init__(self, *, token: str, host_url: str, **kwargs) -> None:
novelai.Host.CUSTOM.value.url = host_url
super().__init__(token=token, **kwargs)

async def generate_image(self, prompt: str, negative_prompt: str) -> bytes:
metadata = novelai.Metadata(
prompt=prompt,
negative_prompt=negative_prompt,
res_preset=novelai.Resolution.NORMAL_PORTRAIT,
steps=28,
n_samples=1,
)
images = await super().generate_image(
metadata, host=novelai.Host.CUSTOM, verbose=False, is_opus=False
)
im = images[0]
return im.data
6 changes: 3 additions & 3 deletions src/ui/hoyo/profile/items/add_img_btn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ async def callback(self, i: "INTERACTION") -> None:
await i.response.send_modal(modal)
await modal.wait()

await self.set_loading_state(i)

image_url = modal.image_url.value
if not image_url:
return
Expand All @@ -61,9 +59,11 @@ async def callback(self, i: "INTERACTION") -> None:
if not passed:
raise InvalidImageURLError

await self.set_loading_state(i)

# Upload the image to iili
try:
url = await upload_image(image_url, i.client.session)
url = await upload_image(i.client.session, image_url=image_url)
except Exception as e:
raise InvalidImageURLError from e

Expand Down
3 changes: 3 additions & 0 deletions src/ui/hoyo/profile/items/card_settings_btn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .card_settings_info_btn import CardSettingsInfoButton
from .card_template_select import CardTemplateSelect
from .dark_mode_btn import DarkModeButton
from .gen_ai_art_btn import GenerateAIArtButton
from .image_select import ImageSelect
from .primary_color_btn import PrimaryColorButton
from .remove_img_btn import RemoveImageButton
Expand Down Expand Up @@ -66,6 +67,8 @@ async def callback(self, i: "INTERACTION") -> None:
)
)
self.view.add_item(CardSettingsInfoButton())

self.view.add_item(GenerateAIArtButton())
self.view.add_item(AddImageButton())
self.view.add_item(
RemoveImageButton(self.view._card_settings.current_image in default_arts)
Expand Down
95 changes: 95 additions & 0 deletions src/ui/hoyo/profile/items/gen_ai_art_btn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import TYPE_CHECKING

from discord import ButtonStyle, TextStyle
from discord.file import File

from src.bot.translator import LocaleStr
from src.constants import NSFW_TAGS
from src.exceptions import GuildOnlyFeatureError, NSFWPromptError
from src.ui.components import Button, Modal, TextInput

from .....utils import upload_image

if TYPE_CHECKING:
from src.bot.bot import INTERACTION

from ..view import ProfileView # noqa: F401
from .image_select import ImageSelect
from .remove_img_btn import RemoveImageButton


class GenerateAIArtModal(Modal):
prompt = TextInput(
label=LocaleStr("Prompt", key="profile.generate_ai_art_modal.prompt.label"),
placeholder="navia(genshin impact), foaml dress, idol, beautiful dress, elegant, best quality, aesthetic...",
style=TextStyle.paragraph,
max_length=250,
)

negative_prompt = TextInput(
label=LocaleStr(
"Negative Prompt", key="profile.generate_ai_art_modal.negative_prompt.label"
),
placeholder="bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs...",
style=TextStyle.paragraph,
max_length=200,
required=False,
)


class GenerateAIArtButton(Button):
def __init__(self) -> None:
super().__init__(
label=LocaleStr("Generate AI Art", key="profile.generate_ai_art.button.label"),
style=ButtonStyle.blurple,
row=3,
)

async def callback(self, i: "INTERACTION") -> None:
if i.guild is None:
raise GuildOnlyFeatureError

modal = GenerateAIArtModal(
title=LocaleStr("Generate AI Art", key="profile.generate_ai_art_modal.title")
)
modal.translate(self.view.locale, self.view.translator)
await i.response.send_modal(modal)
await modal.wait()

if not modal.prompt.value:
return

prompt = modal.prompt.value
negative_prompt = modal.negative_prompt.value
if any(tag.lower() in prompt.lower() for tag in NSFW_TAGS):
raise NSFWPromptError

await self.set_loading_state(i)

client = i.client.nai_client
bytes_ = await client.generate_image(prompt, negative_prompt)
url = await upload_image(i.client.session, image=bytes_)

# Add the image URL to db
self.view._card_settings.custom_images.append(url)
self.view._card_settings.current_image = url
await self.view._card_settings.save()

# Add the new image URL to the image select options
image_select: ImageSelect = self.view.get_item("profile_image_select")
image_select.options_before_split = image_select.generate_options()
image_select.options = image_select.process_options()
# Set the new image as the default (selected) option
image_select.update_options_defaults(values=[url])
image_select.translate(self.view.locale, self.view.translator)

# Enable the remove image button
remove_img_btn: RemoveImageButton = self.view.get_item("profile_remove_image")
remove_img_btn.disabled = False

# Redraw the card
bytes_obj = await self.view.draw_card(i)
bytes_obj.seek(0)
await self.unset_loading_state(
i, attachments=[File(bytes_obj, filename="card.webp")], embed=None
)
11 changes: 10 additions & 1 deletion src/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import datetime
import logging
import math
Expand Down Expand Up @@ -72,13 +73,21 @@ def round_down(number: float, decimals: int) -> float:
return result


async def upload_image(image_url: str, session: aiohttp.ClientSession) -> str:
async def upload_image(
session: aiohttp.ClientSession, *, image_url: str | None = None, image: bytes | None = None
) -> str:
api = "https://freeimage.host/api/1/upload"
data = {
"key": "6d207e02198a847aa98d0a2a901485a5",
"source": image_url,
"format": "json",
}

if image is not None:
# Encode image into base64 string
image_base64 = base64.b64encode(image).decode("utf-8")
data["source"] = image_base64

async with session.post(api, data=data) as resp:
resp.raise_for_status()

Expand Down

0 comments on commit 307c4a6

Please sign in to comment.