diff --git a/src/bot/bot.py b/src/bot/bot.py index 93512b70..a9b811c5 100644 --- a/src/bot/bot.py +++ b/src/bot/bot.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path from typing import TYPE_CHECKING, TypeAlias @@ -9,6 +10,7 @@ from cachetools import TTLCache from discord.ext import commands, tasks +from ..hoyo.novelai_client import NAIClient from ..utils import get_now from .command_tree import CommandTree from .translator import AppCommandTranslator, Translator @@ -60,6 +62,9 @@ def __init__( self.translator = translator self.env = env self.diskcache = diskcache.Cache("./.cache/hoyo_buddy") + self.nai_client = NAIClient( + token=os.environ["NAI_TOKEN"], host_url=os.environ["NAI_HOST_URL"] + ) async def setup_hook(self) -> None: await self.tree.set_translator(AppCommandTranslator(self.translator)) @@ -74,6 +79,8 @@ async def setup_hook(self) -> None: await self.load_extension("jishaku") + await self.nai_client.init(timeout=120) + self.push_source_strings.start() def capture_exception(self, e: Exception) -> None: diff --git a/src/constants.py b/src/constants.py index 1c7608b9..dd90ec2d 100644 --- a/src/constants.py +++ b/src/constants.py @@ -169,3 +169,153 @@ EnkaLanguage.ITALIAN: "it", EnkaLanguage.TURKISH: "tr", } + +# NSFW tags for AI art generation feature +NSFW_TAGS: set[str] = { + "nsfw", + "pussy", + "vaginal", + "futanari", + "nipple pull", + "cowgirl position", + "cum on breast", + "happy sex", + "bound arms", + "cock in thighhigh", + "deepthroat", + "fruit insertion", + "transformation", + "mound of venus", + "footjob", + "wakamezake", + "leg lock", + "upright straddle", + "penis", + "thigh sex", + "stomach bulge", + "testicles", + "female ejaculation", + "breast grab", + "spanked", + "mind control", + "cum on food", + "nyotaimori", + "breast feeding", + "double penetration", + "double vaginal", + "tally", + "femdom", + "cum", + "grinding", + "ring gag", + "groping", + "enema", + "tribadism", + "navel piercing", + "buttjob", + "vibrator", + "vibrator in thighhighs", + "masturbation", + "pubic hair", + "cunnilingus", + "clitoris", + "body writing", + "clothed masturbation", + "frogtie", + "facial", + "cameltoe", + "x-ray", + "cross-section", + "internal cumshot", + "exhibitionism", + "bra lift", + "caught", + "blood", + "tapegag", + "tamakeri", + "nipple torture", + "walk-in", + "gangbang", + "anal insertion", + "anus", + "anilingus", + "anal beads", + "voyeurism", + "crotch rope", + "crotch rub", + "rope", + "humiliation", + "group sex", + "orgy", + "teamwork", + "spread ass", + "no pussy", + "underwater sex", + "insertion", + "nipple tweak", + "rape", + "penetration", + "cum on hair", + "bound wrists", + "ejaculation", + "reverse cowgirl", + "lactation", + "breast sucking", + "nipple suck", + "girl on top", + "wide hips", + "large insertion", + "anal fingering", + "bondage", + "bitgag", + "handjob", + "cleave gag", + "bdsm", + "pillory", + "stocks", + "shibari", + "fingering", + "cock ring", + "huge ass", + "censored", + "gokkun", + "vore", + "puffy nipples", + "leash", + "gag", + "ass", + "amputee", + "hitachi magic wand", + "uncensored", + "panty gag", + "fat mons", + "ballgag", + "clothed sex", + "piercing", + "anal", + "pussy juice", + "doggystyle", + "ganguro", + "hogtie", + "wooden horse", + "breast smother", + "fisting", + "suspension", + "anal fisting", + "have to pee", + "peeing", + "small nipples", + "cervix", + "multiple paizuri", + "oral", + "fellatio", + "hairjob", + "virgin", + "facesitting", + "double anal", + "pegging", + "slave", + "about to be raped", + "sex", + "molestation", +} diff --git a/src/exceptions.py b/src/exceptions.py index ae50ef8a..ec3a05a7 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -128,3 +128,25 @@ def __init__(self) -> None: key="verification_code_service_unavailable_description", ), ) + + +class NSFWPromptError(HoyoBuddyError): + def __init__(self) -> None: + super().__init__( + title=LocaleStr("NSFW Prompt", key="nsfw_prompt_error_title"), + message=LocaleStr( + "The prompt contains NSFW content, please try again with a different prompt.", + key="nsfw_prompt_error_message", + ), + ) + + +class GuildOnlyFeatureError(HoyoBuddyError): + def __init__(self) -> None: + super().__init__( + title=LocaleStr("Guild Only Feature", key="guild_only_feature_error_title"), + message=LocaleStr( + "This feature is only available in guilds, please try again in a guild.", + key="guild_only_feature_error_message", + ), + ) diff --git a/src/hoyo/novelai_client.py b/src/hoyo/novelai_client.py new file mode 100644 index 00000000..69416bc1 --- /dev/null +++ b/src/hoyo/novelai_client.py @@ -0,0 +1,21 @@ +import novelai + + +class NAIClient(novelai.NAIClient): + def __init__(self, *, token: str, host_url: str, **kwargs) -> None: + novelai.Host.CUSTOM.value.url = host_url + super().__init__(token=token, **kwargs) + + async def generate_image(self, prompt: str, negative_prompt: str) -> bytes: + metadata = novelai.Metadata( + prompt=prompt, + negative_prompt=negative_prompt, + res_preset=novelai.Resolution.NORMAL_PORTRAIT, + steps=28, + n_samples=1, + ) + images = await super().generate_image( + metadata, host=novelai.Host.CUSTOM, verbose=False, is_opus=False + ) + im = images[0] + return im.data diff --git a/src/ui/hoyo/profile/items/add_img_btn.py b/src/ui/hoyo/profile/items/add_img_btn.py index c7fe9e03..af3f17eb 100644 --- a/src/ui/hoyo/profile/items/add_img_btn.py +++ b/src/ui/hoyo/profile/items/add_img_btn.py @@ -47,8 +47,6 @@ async def callback(self, i: "INTERACTION") -> None: await i.response.send_modal(modal) await modal.wait() - await self.set_loading_state(i) - image_url = modal.image_url.value if not image_url: return @@ -61,9 +59,11 @@ async def callback(self, i: "INTERACTION") -> None: if not passed: raise InvalidImageURLError + await self.set_loading_state(i) + # Upload the image to iili try: - url = await upload_image(image_url, i.client.session) + url = await upload_image(i.client.session, image_url=image_url) except Exception as e: raise InvalidImageURLError from e diff --git a/src/ui/hoyo/profile/items/card_settings_btn.py b/src/ui/hoyo/profile/items/card_settings_btn.py index 27a28418..64f73f73 100644 --- a/src/ui/hoyo/profile/items/card_settings_btn.py +++ b/src/ui/hoyo/profile/items/card_settings_btn.py @@ -9,6 +9,7 @@ from .card_settings_info_btn import CardSettingsInfoButton from .card_template_select import CardTemplateSelect from .dark_mode_btn import DarkModeButton +from .gen_ai_art_btn import GenerateAIArtButton from .image_select import ImageSelect from .primary_color_btn import PrimaryColorButton from .remove_img_btn import RemoveImageButton @@ -66,6 +67,8 @@ async def callback(self, i: "INTERACTION") -> None: ) ) self.view.add_item(CardSettingsInfoButton()) + + self.view.add_item(GenerateAIArtButton()) self.view.add_item(AddImageButton()) self.view.add_item( RemoveImageButton(self.view._card_settings.current_image in default_arts) diff --git a/src/ui/hoyo/profile/items/gen_ai_art_btn.py b/src/ui/hoyo/profile/items/gen_ai_art_btn.py new file mode 100644 index 00000000..b3ef170a --- /dev/null +++ b/src/ui/hoyo/profile/items/gen_ai_art_btn.py @@ -0,0 +1,95 @@ +from typing import TYPE_CHECKING + +from discord import ButtonStyle, TextStyle +from discord.file import File + +from src.bot.translator import LocaleStr +from src.constants import NSFW_TAGS +from src.exceptions import GuildOnlyFeatureError, NSFWPromptError +from src.ui.components import Button, Modal, TextInput + +from .....utils import upload_image + +if TYPE_CHECKING: + from src.bot.bot import INTERACTION + + from ..view import ProfileView # noqa: F401 + from .image_select import ImageSelect + from .remove_img_btn import RemoveImageButton + + +class GenerateAIArtModal(Modal): + prompt = TextInput( + label=LocaleStr("Prompt", key="profile.generate_ai_art_modal.prompt.label"), + placeholder="navia(genshin impact), foaml dress, idol, beautiful dress, elegant, best quality, aesthetic...", + style=TextStyle.paragraph, + max_length=250, + ) + + negative_prompt = TextInput( + label=LocaleStr( + "Negative Prompt", key="profile.generate_ai_art_modal.negative_prompt.label" + ), + placeholder="bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs...", + style=TextStyle.paragraph, + max_length=200, + required=False, + ) + + +class GenerateAIArtButton(Button): + def __init__(self) -> None: + super().__init__( + label=LocaleStr("Generate AI Art", key="profile.generate_ai_art.button.label"), + style=ButtonStyle.blurple, + row=3, + ) + + async def callback(self, i: "INTERACTION") -> None: + if i.guild is None: + raise GuildOnlyFeatureError + + modal = GenerateAIArtModal( + title=LocaleStr("Generate AI Art", key="profile.generate_ai_art_modal.title") + ) + modal.translate(self.view.locale, self.view.translator) + await i.response.send_modal(modal) + await modal.wait() + + if not modal.prompt.value: + return + + prompt = modal.prompt.value + negative_prompt = modal.negative_prompt.value + if any(tag.lower() in prompt.lower() for tag in NSFW_TAGS): + raise NSFWPromptError + + await self.set_loading_state(i) + + client = i.client.nai_client + bytes_ = await client.generate_image(prompt, negative_prompt) + url = await upload_image(i.client.session, image=bytes_) + + # Add the image URL to db + self.view._card_settings.custom_images.append(url) + self.view._card_settings.current_image = url + await self.view._card_settings.save() + + # Add the new image URL to the image select options + image_select: ImageSelect = self.view.get_item("profile_image_select") + image_select.options_before_split = image_select.generate_options() + image_select.options = image_select.process_options() + # Set the new image as the default (selected) option + image_select.update_options_defaults(values=[url]) + image_select.translate(self.view.locale, self.view.translator) + + # Enable the remove image button + remove_img_btn: RemoveImageButton = self.view.get_item("profile_remove_image") + remove_img_btn.disabled = False + + # Redraw the card + bytes_obj = await self.view.draw_card(i) + bytes_obj.seek(0) + await self.unset_loading_state( + i, attachments=[File(bytes_obj, filename="card.webp")], embed=None + ) diff --git a/src/utils.py b/src/utils.py index 97271981..7a7ed657 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,3 +1,4 @@ +import base64 import datetime import logging import math @@ -72,13 +73,21 @@ def round_down(number: float, decimals: int) -> float: return result -async def upload_image(image_url: str, session: aiohttp.ClientSession) -> str: +async def upload_image( + session: aiohttp.ClientSession, *, image_url: str | None = None, image: bytes | None = None +) -> str: api = "https://freeimage.host/api/1/upload" data = { "key": "6d207e02198a847aa98d0a2a901485a5", "source": image_url, "format": "json", } + + if image is not None: + # Encode image into base64 string + image_base64 = base64.b64encode(image).decode("utf-8") + data["source"] = image_base64 + async with session.post(api, data=data) as resp: resp.raise_for_status()