Skip to content

Commit

Permalink
Add support for Sve.StoreNarrowing() (#102605)
Browse files Browse the repository at this point in the history
* Add Sve.StoreNarrowing()

* Incorporate review comments for Sve.StoreAndZip()

* Fix formatting issues
  • Loading branch information
SwapnilGaikwad authored May 26, 2024
1 parent 35e4aad commit 0588f24
Show file tree
Hide file tree
Showing 13 changed files with 609 additions and 80 deletions.
1 change: 1 addition & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26711,6 +26711,7 @@ bool GenTreeHWIntrinsic::OperIsMemoryStore(GenTree** pAddr) const
case NI_Sve_StoreAndZipx2:
case NI_Sve_StoreAndZipx3:
case NI_Sve_StoreAndZipx4:
case NI_Sve_StoreNarrowing:
addr = Op(2);
break;
#endif // TARGET_ARM64
Expand Down
28 changes: 28 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,34 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_StoreNarrowing:
{
assert(sig->numArgs == 3);
assert(retType == TYP_VOID);

CORINFO_ARG_LIST_HANDLE arg = sig->args;
arg = info.compCompHnd->getArgNext(arg);
CORINFO_CLASS_HANDLE argClass = info.compCompHnd->getArgClass(sig, arg);
CorInfoType ptrType = getBaseJitTypeAndSizeOfSIMDType(argClass);
CORINFO_CLASS_HANDLE tmpClass = NO_CLASS_HANDLE;

// The size of narrowed target elements is determined from the second argument of StoreNarrowing().
// Thus, we first extract the datatype of a pointer passed in the second argument and then store it as the
// auxiliary type of intrinsic. This auxiliary type is then used in the codegen to choose the correct
// instruction to emit.
ptrType = strip(info.compCompHnd->getArgType(sig, arg, &tmpClass));
assert(ptrType == CORINFO_TYPE_PTR);
ptrType = info.compCompHnd->getChildType(argClass, &tmpClass);
assert(ptrType < simdBaseJitType);

op3 = impPopStack().val;
op2 = impPopStack().val;
op1 = impPopStack().val;
retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(ptrType);
break;
}

default:
{
return nullptr;
Expand Down
9 changes: 9 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
ins = varTypeIsUnsigned(intrin.baseType) ? INS_umsubl : INS_smsubl;
break;

case NI_Sve_StoreNarrowing:
ins = HWIntrinsicInfo::lookupIns(intrin.id, node->GetAuxiliaryType());
break;

default:
ins = HWIntrinsicInfo::lookupIns(intrin.id, intrin.baseType);
break;
Expand Down Expand Up @@ -1773,6 +1777,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_StoreNarrowing:
opt = emitter::optGetSveInsOpt(emitTypeSize(intrin.baseType));
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);
break;

case NI_Sve_UnzipEven:
case NI_Sve_UnzipOdd:
case NI_Sve_ZipHigh:
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ HARDWARE_INTRINSIC(Sve, SignExtend8,
HARDWARE_INTRINSIC(Sve, SignExtendWideningLower, -1, 1, true, {INS_sve_sunpklo, INS_invalid, INS_sve_sunpklo, INS_invalid, INS_sve_sunpklo, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Sve, SignExtendWideningUpper, -1, 1, true, {INS_sve_sunpkhi, INS_invalid, INS_sve_sunpkhi, INS_invalid, INS_sve_sunpkhi, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Sve, StoreAndZip, -1, 3, true, {INS_sve_st1b, INS_sve_st1b, INS_sve_st1h, INS_sve_st1h, INS_sve_st1w, INS_sve_st1w, INS_sve_st1d, INS_sve_st1d, INS_sve_st1w, INS_sve_st1d}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, StoreNarrowing, -1, 3, true, {INS_sve_st1b, INS_sve_st1b, INS_sve_st1h, INS_sve_st1h, INS_sve_st1w, INS_sve_st1w, INS_sve_st1d, INS_sve_st1d, INS_invalid, INS_invalid}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, Subtract, -1, 2, true, {INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_fsub, INS_sve_fsub}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, SubtractSaturate, -1, 2, true, {INS_sve_sqsub, INS_sve_uqsub, INS_sve_sqsub, INS_sve_uqsub, INS_sve_sqsub, INS_sve_uqsub, INS_sve_sqsub, INS_sve_uqsub, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, UnzipEven, -1, 2, true, {INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2998,6 +2998,80 @@ internal Arm64() { }
/// ST4D {Zdata0.D - Zdata3.D}, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreAndZip(Vector<ulong> mask, ulong* address, (Vector<ulong> Value1, Vector<ulong> Value2, Vector<ulong> Value3, Vector<ulong> Value4) data) { throw new PlatformNotSupportedException(); }
/// Truncate to 8 bits and store

/// <summary>
/// void svst1b[_s16](svbool_t pg, int8_t *base, svint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<short> mask, sbyte* address, Vector<short> data) { throw new PlatformNotSupportedException(); }


/// <summary>
/// void svst1b[_s32](svbool_t pg, int8_t *base, svint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, sbyte* address, Vector<int> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1h[_s32](svbool_t pg, int16_t *base, svint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, short* address, Vector<int> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1b[_s64](svbool_t pg, int8_t *base, svint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, sbyte* address, Vector<long> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1h[_s64](svbool_t pg, int16_t *base, svint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, short* address, Vector<long> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1w[_s64](svbool_t pg, int32_t *base, svint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, int* address, Vector<long> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1b[_u16](svbool_t pg, uint8_t *base, svuint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ushort> mask, byte* address, Vector<ushort> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1b[_u32](svbool_t pg, uint8_t *base, svuint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, byte* address, Vector<uint> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1h[_u32](svbool_t pg, uint16_t *base, svuint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, ushort* address, Vector<uint> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1b[_u64](svbool_t pg, uint8_t *base, svuint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, byte* address, Vector<ulong> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1h[_u64](svbool_t pg, uint16_t *base, svuint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, ushort* address, Vector<ulong> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// void svst1w[_u64](svbool_t pg, uint32_t *base, svuint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, uint* address, Vector<ulong> data) { throw new PlatformNotSupportedException(); }


/// Subtract : Subtract
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3094,6 +3094,80 @@ internal Arm64() { }
/// ST4D {Zdata0.D - Zdata3.D}, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreAndZip(Vector<ulong> mask, ulong* address, (Vector<ulong> Value1, Vector<ulong> Value2, Vector<ulong> Value3, Vector<ulong> Value4) data) => StoreAndZip(mask, address, data);
/// Truncate to 8 bits and store


/// <summary>
/// void svst1b[_s16](svbool_t pg, int8_t *base, svint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<short> mask, sbyte* address, Vector<short> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1b[_s32](svbool_t pg, int8_t *base, svint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, sbyte* address, Vector<int> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1h[_s32](svbool_t pg, int16_t *base, svint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, short* address, Vector<int> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1b[_s64](svbool_t pg, int8_t *base, svint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, sbyte* address, Vector<long> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1h[_s64](svbool_t pg, int16_t *base, svint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, short* address, Vector<long> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1w[_s64](svbool_t pg, int32_t *base, svint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, int* address, Vector<long> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1b[_u16](svbool_t pg, uint8_t *base, svuint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ushort> mask, byte* address, Vector<ushort> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1b[_u32](svbool_t pg, uint8_t *base, svuint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, byte* address, Vector<uint> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1h[_u32](svbool_t pg, uint16_t *base, svuint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, ushort* address, Vector<uint> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1b[_u64](svbool_t pg, uint8_t *base, svuint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, byte* address, Vector<ulong> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1h[_u64](svbool_t pg, uint16_t *base, svuint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, ushort* address, Vector<ulong> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// void svst1w[_u64](svbool_t pg, uint32_t *base, svuint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, uint* address, Vector<ulong> data) => StoreNarrowing(mask, address, data);


/// Subtract : Subtract
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4588,6 +4588,19 @@ internal Arm64() { }
public static unsafe void StoreAndZip(System.Numerics.Vector<ulong> mask, ulong* address, (System.Numerics.Vector<ulong> Value1, System.Numerics.Vector<ulong> Value2, System.Numerics.Vector<ulong> Value3) data) { throw null; }
public static unsafe void StoreAndZip(System.Numerics.Vector<ulong> mask, ulong* address, (System.Numerics.Vector<ulong> Value1, System.Numerics.Vector<ulong> Value2, System.Numerics.Vector<ulong> Value3, System.Numerics.Vector<ulong> Value4) data) { throw null; }

public static unsafe void StoreNarrowing(System.Numerics.Vector<short> mask, sbyte* address, System.Numerics.Vector<short> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<int> mask, sbyte* address, System.Numerics.Vector<int> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<int> mask, short* address, System.Numerics.Vector<int> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<long> mask, sbyte* address, System.Numerics.Vector<long> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<long> mask, short* address, System.Numerics.Vector<long> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<long> mask, int* address, System.Numerics.Vector<long> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ushort> mask, byte* address, System.Numerics.Vector<ushort> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<uint> mask, byte* address, System.Numerics.Vector<uint> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<uint> mask, ushort* address, System.Numerics.Vector<uint> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ulong> mask, byte* address, System.Numerics.Vector<ulong> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ulong> mask, ushort* address, System.Numerics.Vector<ulong> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ulong> mask, uint* address, System.Numerics.Vector<ulong> data) { throw null; }

public static System.Numerics.Vector<sbyte> Subtract(System.Numerics.Vector<sbyte> left, System.Numerics.Vector<sbyte> right) { throw null; }
public static System.Numerics.Vector<short> Subtract(System.Numerics.Vector<short> left, System.Numerics.Vector<short> right) { throw null; }
public static System.Numerics.Vector<int> Subtract(System.Numerics.Vector<int> left, System.Numerics.Vector<int> right) { throw null; }
Expand Down
Loading

0 comments on commit 0588f24

Please sign in to comment.