mirror of
https://github.com/vladmandic/automatic
synced 2026-04-09 10:11:53 +02:00
@@ -31,6 +31,7 @@ const jsConfig = defineConfig([
|
||||
...globals.jquery,
|
||||
panzoom: 'readonly',
|
||||
authFetch: 'readonly',
|
||||
initServerInfo: 'readonly',
|
||||
log: 'readonly',
|
||||
debug: 'readonly',
|
||||
error: 'readonly',
|
||||
|
||||
Submodule extensions-builtin/sdnext-modernui updated: f90df02fe0...ba8955c02f
62
installer.py
62
installer.py
@@ -20,6 +20,11 @@ class Dot(dict): # dot notation access to dictionary attributes
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
class Torch(dict):
    """Dictionary of torch runtime facts that is filled in incrementally.

    Callers record entries with keyword arguments, e.g.
    ``torch_info.set(type='cuda', attention='sdpa')``; a later call with the
    same key overwrites the earlier value.
    """

    def set(self, **kwargs):
        """Store every keyword argument as a key/value pair in this mapping.

        Returns:
            None. The mapping is updated in place.
        """
        # dict.update performs the same per-key assignment as an explicit loop.
        self.update(kwargs)
|
||||
|
||||
version = {
|
||||
'app': 'sd.next',
|
||||
'updated': 'unknown',
|
||||
@@ -75,6 +80,8 @@ control_extensions = [ # 3rd party extensions marked as safe for control ui
|
||||
'IP Adapters',
|
||||
'Remove background',
|
||||
]
|
||||
gpu_info = []
|
||||
torch_info = Torch()
|
||||
|
||||
|
||||
try:
|
||||
@@ -588,7 +595,7 @@ def install_rocm_zluda():
|
||||
if device_id < len(amd_gpus):
|
||||
device = amd_gpus[device_id]
|
||||
|
||||
if sys.platform == "win32" and not args.use_zluda and device is not None and device.therock is not None and not installed("rocm"):
|
||||
if sys.platform == "win32" and (not args.use_zluda) and (device is not None) and (device.therock is not None) and not installed("rocm"):
|
||||
check_python(supported_minors=[11, 12, 13], reason='ROCm backend requires a Python version between 3.11 and 3.13')
|
||||
install(f"rocm[devel,libraries] --index-url https://rocm.nightlies.amd.com/{device.therock}")
|
||||
rocm.refresh()
|
||||
@@ -750,6 +757,18 @@ def check_cudnn():
|
||||
os.environ['CUDA_PATH'] = cuda_path
|
||||
|
||||
|
||||
def get_cuda_arch(capability):
    """Format a CUDA compute capability as ``"<major>.<minor> <architecture>"``.

    Args:
        capability: ``(major, minor)`` pair; the caller passes the result of
            ``torch.cuda.get_device_capability``.

    Returns:
        A string such as ``"8.9 Ada Lovelace"``. The architecture part is
        ``"Unknown"`` when the major version has no entry in the table.
    """
    major, minor = capability
    mapping = {
        # Blackwell spans several major versions, so newer GPUs no longer
        # fall through to "Unknown".
        12: "Blackwell",
        11: "Blackwell",
        10: "Blackwell",
        9: "Hopper",
        8: "Ada Lovelace" if minor == 9 else "Ampere",  # only 8.9 is Ada
        7: "Turing" if minor == 5 else "Volta",  # only 7.5 is Turing
        6: "Pascal",
        5: "Maxwell",
        3: "Kepler",
    }
    name = mapping.get(major, "Unknown")
    return f"{major}.{minor} {name}"
|
||||
|
||||
|
||||
# check torch version
|
||||
def check_torch():
|
||||
log.info('Torch: verifying installation')
|
||||
@@ -832,6 +851,7 @@ def check_torch():
|
||||
log.info(f'Torch backend: type=IPEX version={ipex.__version__}')
|
||||
except Exception:
|
||||
pass
|
||||
torch_info.set(version=torch.__version__)
|
||||
if 'cpu' in torch.__version__:
|
||||
if is_cuda_available:
|
||||
if args.use_cuda:
|
||||
@@ -845,20 +865,44 @@ def check_torch():
|
||||
install(torch_command, 'torch torchvision', quiet=True, reinstall=True, force=True) # foce reinstall
|
||||
else:
|
||||
log.warning(f'Torch: version="{torch.__version__}" CPU version installed and ROCm is available - consider reinstalling')
|
||||
if args.use_openvino:
|
||||
torch_info.set(type='openvino')
|
||||
else:
|
||||
torch_info.set(type='cpu')
|
||||
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available() and allow_ipex:
|
||||
if shutil.which('icpx') is not None:
|
||||
log.info(f'{os.popen("icpx --version").read().rstrip()}')
|
||||
torch_info.set(type='xpu', oneapi=torch.xpu.runtime_version(), dpc=torch.xpu.dpcpp_version(), driver=torch.xpu.driver_version())
|
||||
for device in range(torch.xpu.device_count()):
|
||||
log.info(f'Torch detected: gpu="{torch.xpu.get_device_name(device)}" vram={round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} units={torch.xpu.get_device_properties(device).max_compute_units}')
|
||||
gpu = {
|
||||
'gpu': torch.xpu.get_device_name(device),
|
||||
'vram': round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024),
|
||||
'units': torch.xpu.get_device_properties(device).max_compute_units,
|
||||
}
|
||||
log.info(f'Torch detected: {gpu}')
|
||||
gpu_info.append(gpu)
|
||||
|
||||
elif torch.cuda.is_available() and (allow_cuda or allow_rocm):
|
||||
if torch.version.cuda and allow_cuda:
|
||||
log.info(f'Torch backend: version="{torch.__version__}" type=CUDA CUDA={torch.version.cuda} cuDNN={torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}')
|
||||
if args.use_zluda:
|
||||
torch_info.set(type="zluda", cuda=torch.version.cuda)
|
||||
elif torch.version.cuda and allow_cuda:
|
||||
torch_info.set(type='cuda', cuda=torch.version.cuda, cudnn=torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 'N/A')
|
||||
elif torch.version.hip and allow_rocm:
|
||||
log.info(f'Torch backend: version="{torch.__version__}" type=ROCm HIP={torch.version.hip}')
|
||||
torch_info.set(type='rocm', hip=torch.version.hip)
|
||||
else:
|
||||
log.warning('Unknown Torch backend')
|
||||
log.info(f"Torch backend: {torch_info}")
|
||||
for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]:
|
||||
log.info(f'Torch detected: gpu="{torch.cuda.get_device_name(device)}" vram={round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} arch={torch.cuda.get_device_capability(device)} cores={torch.cuda.get_device_properties(device).multi_processor_count}')
|
||||
gpu = {
|
||||
'gpu': torch.cuda.get_device_name(device),
|
||||
'vram': round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024),
|
||||
'arch': get_cuda_arch(torch.cuda.get_device_capability(device)),
|
||||
'cores': torch.cuda.get_device_properties(device).multi_processor_count,
|
||||
}
|
||||
gpu_info.append(gpu)
|
||||
log.info(f'Torch detected: {gpu}')
|
||||
|
||||
else:
|
||||
try:
|
||||
if args.use_directml and allow_directml:
|
||||
@@ -867,7 +911,11 @@ def check_torch():
|
||||
log.warning(f'Torch backend: DirectML ({dml_ver})')
|
||||
log.warning('DirectML: end-of-life')
|
||||
for i in range(0, torch_directml.device_count()):
|
||||
log.info(f'Torch detected GPU: {torch_directml.device_name(i)}')
|
||||
gpu = {
|
||||
'gpu': torch_directml.device_name(i),
|
||||
}
|
||||
gpu_info.append(gpu)
|
||||
log.info(f'Torch detected GPU: {gpu}')
|
||||
except Exception:
|
||||
log.warning("Torch reports CUDA not available")
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,9 +24,9 @@ async function authFetch(url, options = {}) {
|
||||
let res;
|
||||
try {
|
||||
res = await fetch(url, options);
|
||||
if (!res.ok) error('fetch', { status: res.status, url, user, token });
|
||||
if (!res.ok) error('fetch', { status: res?.status || 503, url, user, token });
|
||||
} catch (err) {
|
||||
error('fetch', { status: res.status, url, user, token, error: err });
|
||||
error('fetch', { status: res?.status || 503, url, user, token, error: err });
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ async function updateGPU() {
|
||||
const gpuEl = document.getElementById('gpu');
|
||||
const gpuTable = document.getElementById('gpu-table');
|
||||
try {
|
||||
const res = await authFetch(`${window.api}/gpu`);
|
||||
const res = await authFetch(`${window.api}/gpu-smi`);
|
||||
if (!res.ok) {
|
||||
clearInterval(gpuInterval);
|
||||
gpuEl.style.display = 'none';
|
||||
|
||||
@@ -39,12 +39,14 @@ class Api:
|
||||
def register(self):
|
||||
# fetch js/css
|
||||
self.add_api_route("/js", server.get_js, methods=["GET"], auth=False)
|
||||
|
||||
# server api
|
||||
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
|
||||
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str])
|
||||
self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/torch", server.get_torch, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus)
|
||||
self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
|
||||
@@ -54,7 +56,8 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=list[models.ResGPU])
|
||||
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/gpu-smi", gpu.get_gpu_smi, methods=["GET"], response_model=list[models.ResGPU])
|
||||
|
||||
# core api using locking
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img, tags=["Generation"])
|
||||
|
||||
@@ -5,7 +5,17 @@ from modules.logger import log
|
||||
device = None
|
||||
|
||||
|
||||
def get_gpu_status():
|
||||
def get_gpu():
    """Return the GPU descriptions collected in ``installer.gpu_info``.

    Returns:
        The single entry itself when exactly one GPU was recorded; otherwise
        a dict mapping each entry's list index to that entry (an empty dict
        when nothing was recorded).
    """
    import installer

    detected = installer.gpu_info
    if len(detected) == 1:
        return detected[0]
    return dict(enumerate(detected))
|
||||
|
||||
|
||||
def get_gpu_smi():
|
||||
"""Return real-time GPU metrics (utilization, temperature, memory, clock speeds) via vendor-specific APIs (NVML, ROCm SMI, XPU SMI)."""
|
||||
global device # pylint: disable=global-statement
|
||||
if device is None:
|
||||
@@ -37,5 +47,5 @@ class ResGPU(BaseModel):
|
||||
|
||||
if __name__ == '__main__':
|
||||
from rich import print as rprint
|
||||
for gpu in get_gpu_status():
|
||||
for gpu in get_gpu_smi():
|
||||
rprint(gpu)
|
||||
|
||||
@@ -19,7 +19,7 @@ ignore_endpoints = [
|
||||
'/sdapi/v1/version',
|
||||
'/sdapi/v1/log',
|
||||
'/sdapi/v1/browser',
|
||||
'/sdapi/v1/gpu',
|
||||
'/sdapi/v1/gpu-smi',
|
||||
'/sdapi/v1/network/thumb',
|
||||
'/sdapi/v1/progress',
|
||||
]
|
||||
|
||||
@@ -8,14 +8,6 @@ from modules import shared
|
||||
from modules.logger import log
|
||||
from modules.api import models, helpers
|
||||
|
||||
def _get_version():
|
||||
return installer.get_version()
|
||||
|
||||
|
||||
def post_shutdown():
|
||||
log.info('Shutdown request received')
|
||||
import sys
|
||||
sys.exit(0)
|
||||
|
||||
def get_js(request: Request):
|
||||
file = request.query_params.get("file", None)
|
||||
@@ -43,32 +35,34 @@ def get_js(request: Request):
|
||||
media_type = 'application/octet-stream'
|
||||
return FileResponse(file, media_type=media_type)
|
||||
|
||||
def get_version():
    """Return the version information reported by ``installer.get_version``."""
    version_info = installer.get_version()
    return version_info
|
||||
|
||||
def get_motd():
|
||||
import requests
|
||||
motd = ''
|
||||
ver = _get_version()
|
||||
if ver.get('updated', None) is not None:
|
||||
motd = f"version <b>{ver['commit']} {ver['updated']}</b> <span style='color: var(--primary-500)'>{ver['url'].split('/')[-1]}</span><br>" # pylint: disable=use-maxsplit-arg
|
||||
motd = ""
|
||||
ver = get_version()
|
||||
if ver.get("updated", None) is not None:
|
||||
motd = f"version <b>{ver['commit']} {ver['updated']}</b> <span style='color: var(--primary-500)'>{ver['url'].split('/')[-1]}</span><br>" # pylint: disable=use-maxsplit-arg
|
||||
if shared.opts.motd:
|
||||
try:
|
||||
res = requests.get('https://vladmandic.github.io/sdnext/motd', timeout=3)
|
||||
res = requests.get("https://vladmandic.github.io/sdnext/motd", timeout=3)
|
||||
if res.status_code == 200:
|
||||
msg = (res.text or '').strip()
|
||||
log.info(f'MOTD: {msg if len(msg) > 0 else "N/A"}')
|
||||
msg = (res.text or "").strip()
|
||||
log.info(f"MOTD: {msg if len(msg) > 0 else 'N/A'}")
|
||||
motd += res.text
|
||||
else:
|
||||
log.error(f'MOTD: {res.status_code}')
|
||||
log.error(f"MOTD: {res.status_code}")
|
||||
except Exception as err:
|
||||
log.error(f'MOTD: {err}')
|
||||
log.error(f"MOTD: {err}")
|
||||
return motd
|
||||
|
||||
def get_version():
|
||||
return _get_version()
|
||||
|
||||
def get_platform():
|
||||
from installer import get_platform as installer_get_platform
|
||||
from modules.loader import get_packages as loader_get_packages
|
||||
return { **installer_get_platform(), **loader_get_packages() }
|
||||
return { **installer.get_platform(), **loader_get_packages() }
|
||||
|
||||
def get_torch():
    """Return a plain ``dict`` copy of the entries in ``installer.torch_info``."""
    snapshot = {**installer.torch_info}
    return snapshot
|
||||
|
||||
def get_log(req: models.ReqGetLog = Depends()):
|
||||
lines = log.buffer[:req.lines] if req.lines > 0 else log.buffer.copy()
|
||||
@@ -85,6 +79,11 @@ def post_log(req: models.ReqPostLog):
|
||||
log.error(f'UI: {req.error}')
|
||||
return {}
|
||||
|
||||
def post_shutdown():
|
||||
log.info("Shutdown request received")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
|
||||
def get_cmd_flags():
|
||||
return vars(shared.cmd_opts)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from functools import wraps
|
||||
import torch
|
||||
from modules import rocm
|
||||
from modules.errors import log
|
||||
from installer import install, installed
|
||||
from installer import install, installed, torch_info
|
||||
|
||||
|
||||
def set_dynamic_attention():
|
||||
@@ -10,6 +10,7 @@ def set_dynamic_attention():
|
||||
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
|
||||
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
|
||||
torch_info.set(attention='dynamic')
|
||||
return sdpa_pre_dyanmic_atten
|
||||
except Exception as err:
|
||||
log.error(f'Torch attention: type="dynamic attention" {err}')
|
||||
@@ -20,6 +21,7 @@ def set_triton_flash_attention(backend: str):
|
||||
try:
|
||||
if backend in {"rocm", "zluda"}: # flash_attn_triton_amd only works with AMD
|
||||
from modules.flash_attn_triton_amd import interface_fa
|
||||
|
||||
sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_triton_flash_atten)
|
||||
def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
@@ -42,6 +44,7 @@ def set_triton_flash_attention(backend: str):
|
||||
kwargs["enable_gqa"] = enable_gqa
|
||||
return sdpa_pre_triton_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_triton_flash_atten
|
||||
torch_info.set(attention='triton')
|
||||
log.debug('Torch attention: type="Triton Flash attention"')
|
||||
except Exception as err:
|
||||
log.error(f'Torch attention: type="Triton Flash attention" {err}')
|
||||
@@ -78,6 +81,7 @@ def set_flex_attention():
|
||||
return flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask, scale=scale, enable_gqa=enable_gqa)
|
||||
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_flex_atten
|
||||
torch_info.set(attention="flex")
|
||||
log.debug('Torch attention: type="Flex attention"')
|
||||
except Exception as err:
|
||||
log.error(f'Torch attention: type="Flex attention" {err}')
|
||||
@@ -93,6 +97,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
|
||||
else:
|
||||
install('flash-attn')
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_flash_atten)
|
||||
def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
@@ -120,6 +125,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
|
||||
kwargs["enable_gqa"] = enable_gqa
|
||||
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
|
||||
torch_info.set(attention="flash")
|
||||
log.debug('Torch attention: type="Flash attention"')
|
||||
except Exception as err:
|
||||
log.error(f'Torch attention: type="Flash attention" {err}')
|
||||
@@ -174,6 +180,7 @@ def set_sage_attention(backend: str, device: torch.device):
|
||||
kwargs["enable_gqa"] = enable_gqa
|
||||
return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
|
||||
torch_info.set(attention="sage")
|
||||
log.debug(f'Torch attention: type="Sage attention" backend={"cuda" if use_cuda_backend else "auto"}')
|
||||
except Exception as err:
|
||||
log.error(f'Torch attention: type="Sage attention" {err}')
|
||||
@@ -208,19 +215,22 @@ def set_diffusers_attention(pipe, quiet:bool=False):
|
||||
|
||||
log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"')
|
||||
if shared.opts.cross_attention_optimization == "Disabled":
|
||||
pass # do nothing
|
||||
elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers
|
||||
torch_info.set(attention="disabled")
|
||||
elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers
|
||||
torch_info.set(attention="sdpa")
|
||||
# set_attn(pipe, p.AttnProcessor2_0(), name="Scaled-Dot-Product")
|
||||
pass
|
||||
elif shared.opts.cross_attention_optimization == "xFormers":
|
||||
if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
|
||||
torch_info.set(attention="xformers")
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
log.warning(f"Attention: xFormers is not compatible with {pipe.__class__.__name__}")
|
||||
elif shared.opts.cross_attention_optimization == "Batch matrix-matrix":
|
||||
torch_info.set(attention="bmm")
|
||||
set_attn(pipe, p.AttnProcessor(), name="Batch matrix-matrix")
|
||||
elif shared.opts.cross_attention_optimization == "Dynamic Attention BMM":
|
||||
from modules.sd_hijack_dynamic_atten import DynamicAttnProcessorBMM
|
||||
torch_info.set(attention="dynamic_bmm")
|
||||
set_attn(pipe, DynamicAttnProcessorBMM(), name="Dynamic Attention BMM")
|
||||
|
||||
if shared.opts.attention_slicing != "Default" and hasattr(pipe, "enable_attention_slicing") and hasattr(pipe, "disable_attention_slicing"):
|
||||
|
||||
@@ -4,8 +4,10 @@ import time
|
||||
import contextlib
|
||||
import importlib.metadata
|
||||
import torch
|
||||
from installer import torch_info
|
||||
from modules.logger import log
|
||||
from modules import rocm, attention
|
||||
from modules.errors import log, display, install as install_traceback
|
||||
from modules.errors import display, install as install_traceback
|
||||
|
||||
|
||||
debug = os.environ.get('SD_DEVICE_DEBUG', None) is not None
|
||||
@@ -398,17 +400,35 @@ def test_triton(early: bool = False):
|
||||
test_triton_func(torch.randn(16, device=device), torch.randn(16, device=device), torch.randn(16, device=device))
|
||||
triton_ok = True
|
||||
else:
|
||||
torch_info.set(triton=False)
|
||||
triton_ok = False
|
||||
except Exception as e:
|
||||
torch_info.set(triton=False)
|
||||
triton_ok = False
|
||||
line = str(e).splitlines()[0]
|
||||
log.warning(f"Triton test fail: {line}")
|
||||
if debug:
|
||||
from modules import errors
|
||||
errors.display(e, 'Triton')
|
||||
triton_version = False
|
||||
if triton_ok:
|
||||
if triton_version is None:
|
||||
try:
|
||||
import torch._inductor.triton as torch_triton
|
||||
|
||||
triton_version = torch_triton.__version__
|
||||
except Exception:
|
||||
pass
|
||||
if triton_version is None:
|
||||
try:
|
||||
import triton
|
||||
triton_version = triton.__version__
|
||||
except Exception:
|
||||
pass
|
||||
torch_info.set(triton=triton_version)
|
||||
t1 = time.time()
|
||||
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
||||
log.debug(f'Triton: pass={triton_ok} fn={fn} time={t1-t0:.2f}')
|
||||
log.debug(f'Triton: pass={triton_ok} version={triton_version} fn={fn} time={t1-t0:.2f}')
|
||||
if not triton_ok and opts is not None:
|
||||
opts.sdnq_dequantize_compile = False
|
||||
return triton_ok
|
||||
@@ -469,6 +489,7 @@ def set_sdpa_params():
|
||||
torch.backends.cuda.enable_math_sdp('Math' in opts.sdp_options or 'Math attention' in opts.sdp_options)
|
||||
if hasattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp"): # only valid for torch >= 2.5
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
torch_info.set(attention="sdpa")
|
||||
log.debug(f'Torch attention: type="sdpa" kernels={opts.sdp_options} overrides={opts.sdp_overrides}')
|
||||
except Exception as err:
|
||||
log.warning(f'Torch attention: type="sdpa" {err}')
|
||||
@@ -559,6 +580,10 @@ def set_dtype():
|
||||
inference_context = contextlib.nullcontext
|
||||
else:
|
||||
inference_context = torch.no_grad
|
||||
if dtype == dtype_vae:
|
||||
torch_info.set(dtype=str(dtype))
|
||||
else:
|
||||
torch_info.set(dtype=str(dtype), vae=str(dtype_vae))
|
||||
|
||||
|
||||
def set_cuda_params():
|
||||
|
||||
Reference in New Issue
Block a user