Skip to content

Commit

Permalink
Index size (#272)
Browse files Browse the repository at this point in the history
* temporary solution for guessing the index length

* add a guess_dtype function in utils to help the apply_spans_* functions decide the dtype if dest is not specified

* the GitHub runner won't allow 16 GB for creating the 2**32+1 array with int32, so try int8

* update dtype to int8 in test_utils to avoid excessive memory usage on the GitHub runner

* updated approach for guessing the data type of the dest array:
1. construction of the dest array has now moved to ops.apply_span*
2. the dest array uses the src dtype in min/max operations, and the span dtype in index_min/index_max operations
3. the span dtype is whatever np.nonzero returns, which is np.int64 by default

* minor update

* add a global long index size; add an int32 return for span
TODO: add an int32 return for span with the njit functions. Currently njit is not flexible with int32/int64: tried np.int32/np.int64, 'int32'/'int64', and return astype('int32'); none of them work.

* specify int32/64 type for get_span
  • Loading branch information
deng113jie authored Mar 31, 2022
1 parent 66697ec commit b9a9d64
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 73 deletions.
9 changes: 3 additions & 6 deletions exetera/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,8 +2288,7 @@ def _apply_spans_src(source: Field,
raise ValueError("if 'in_place is True, 'target' must be None")

spans_ = val.array_from_field_or_lower('spans', spans)
results = np.zeros(len(spans) - 1, dtype=source.data.dtype)
predicate(spans_, source.data[:], results)
results = predicate(spans_, source.data[:])

if in_place is True:
if not source._write_enabled:
Expand Down Expand Up @@ -2322,8 +2321,7 @@ def _apply_spans_indexed_src(source: Field,
spans_ = val.array_from_field_or_lower('spans', spans)

# step 1: get the indices through the index predicate
results = np.zeros(len(spans) - 1, dtype=np.int64)
predicate(spans_, source.indices[:], source.values[:], results)
results = predicate(spans_, source.indices[:], source.values[:])

# step 2: run apply_index on the source
return FieldDataOps.apply_index_to_indexed_field(source, results, target, in_place)
Expand All @@ -2341,8 +2339,7 @@ def _apply_spans_indexed_no_src(source: Field,
spans_ = val.array_from_field_or_lower('spans', spans)

# step 1: get the indices through the index predicate
results = np.zeros(len(spans) - 1, dtype=np.int64)
predicate(spans_, results)
results = predicate(spans_)

# step 2: run apply_index on the source
return FieldDataOps.apply_index_to_indexed_field(source, results, target, in_place)
Expand Down
123 changes: 81 additions & 42 deletions exetera/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,10 @@ def get_spans_for_field(ndarray):

results[0] = True
results[-1] = True
return np.nonzero(results)[0]
if len(ndarray) < utils.INT64_INDEX_LENGTH:
return np.nonzero(results)[0].astype('int32')
else:
return np.nonzero(results)[0] # int64 by default


@exetera_njit
Expand All @@ -680,37 +683,52 @@ def _get_spans_for_2_fields_by_spans(span0, span1):
return spans


@exetera_njit
def _get_spans_for_2_fields(ndarray0, ndarray1):
if len(ndarray0) > utils.INT64_INDEX_LENGTH or len(ndarray1) > utils.INT64_INDEX_LENGTH:
spans = np.zeros(len(ndarray0) + 1, dtype=np.int64)
else:
spans = np.zeros(len(ndarray0) + 1, dtype=np.int32)
spans = _get_spans_for_2_fields_njit(ndarray0, ndarray1, spans)
return spans


@exetera_njit
def _get_spans_for_2_fields_njit(ndarray0, ndarray1, spans):
count = 0
spans = np.zeros(len(ndarray0)+1, dtype=np.uint32)
spans[0] = 0
for i in np.arange(1, len(ndarray0)):
if ndarray0[i] != ndarray0[i-1] or ndarray1[i] != ndarray1[i-1]:
if ndarray0[i] != ndarray0[i - 1] or ndarray1[i] != ndarray1[i - 1]:
count += 1
spans[count] = i
spans[count+1] = len(ndarray0)
return spans[:count+2]
spans[count + 1] = len(ndarray0)
return spans[:count + 2]



@exetera_njit
def _get_spans_for_multi_fields(fields_data):
length = len(fields_data[0]) # assume all fields are equal length
if length > utils.INT64_INDEX_LENGTH:
spans = np.zeros(length + 1, dtype=np.int64)
else:
spans = np.zeros(length + 1, dtype=np.int32)
return _get_spans_for_multi_fields_njit(fields_data, spans) # call the njit func to boost performance


@exetera_njit
def _get_spans_for_multi_fields_njit(fields_data, spans):
count = 0
length = len(fields_data[0])
spans = np.zeros(length + 1, dtype = np.uint32)
spans[0] = 0

for i in np.arange(1, length):
not_equal = False
for f_d in fields_data:
if f_d[i] != f_d[i - 1]:
not_equal = True
break

if not_equal:
count += 1
spans[count] = i

spans[count + 1] = length
return spans[:count + 2]

Expand Down Expand Up @@ -763,8 +781,10 @@ def _get_spans_for_index_string_field(indices,values):
return result



@exetera_njit
def apply_spans_index_of_min(spans, src_array, dest_array):
def apply_spans_index_of_min(spans, src_array, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=spans.dtype) if dest_array is None else dest_array
for i in range(len(spans)-1):
cur = spans[i]
next = spans[i+1]
Expand All @@ -778,7 +798,8 @@ def apply_spans_index_of_min(spans, src_array, dest_array):


@exetera_njit
def apply_spans_index_of_min_indexed(spans, src_indices, src_values, dest_array):
def apply_spans_index_of_min_indexed(spans, src_indices, src_values, dest_array=None):
dest_array = np.zeros(len(spans)-1, dtype=spans.dtype) if dest_array is None else dest_array
for i in range(len(spans)-1):
cur = spans[i]
next = spans[i+1]
Expand Down Expand Up @@ -817,7 +838,8 @@ def apply_spans_index_of_min_indexed(spans, src_indices, src_values, dest_array)


@exetera_njit
def apply_spans_index_of_max_indexed(spans, src_indices, src_values, dest_array):
def apply_spans_index_of_max_indexed(spans, src_indices, src_values, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=spans.dtype) if dest_array is None else dest_array
for i in range(len(spans)-1):
cur = spans[i]
next = spans[i+1]
Expand Down Expand Up @@ -855,8 +877,10 @@ def apply_spans_index_of_max_indexed(spans, src_indices, src_values, dest_array)
return dest_array



@exetera_njit
def apply_spans_index_of_max(spans, src_array, dest_array):
def apply_spans_index_of_max(spans, src_array, dest_array=None):
dest_array = np.zeros(len(spans)-1, dtype=spans.dtype) if dest_array is None else dest_array
for i in range(len(spans)-1):
cur = spans[i]
next = spans[i+1]
Expand All @@ -870,13 +894,17 @@ def apply_spans_index_of_max(spans, src_array, dest_array):


@exetera_njit
def apply_spans_index_of_first(spans, dest_array):
def apply_spans_index_of_first(spans, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=spans.dtype) if dest_array is None else dest_array
dest_array[:] = spans[:-1]
return dest_array


@exetera_njit
def apply_spans_index_of_last(spans, dest_array):
def apply_spans_index_of_last(spans, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=spans.dtype) if dest_array is None else dest_array
dest_array[:] = spans[1:] - 1
return dest_array


@exetera_njit
Expand Down Expand Up @@ -941,57 +969,68 @@ def apply_spans_index_of_last_filter(spans, dest_array, filter_array):
return dest_array, filter_array



@exetera_njit
def apply_spans_count(spans, dest_array):
def apply_spans_count(spans, dest_array=None):
if dest_array is None:
dest_array = np.zeros(len(spans) - 1, np.int64)
for i in range(len(spans)-1):
dest_array[i] = np.int64(spans[i+1] - spans[i])
dest_array[i] = spans[i+1] - spans[i]
return dest_array


@exetera_njit
def apply_spans_first(spans, src_array, dest_array):
def apply_spans_first(spans, src_array, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=src_array.dtype) if dest_array is None else dest_array
dest_array[:] = src_array[spans[:-1]]
return dest_array


@exetera_njit
def apply_spans_last(spans, src_array, dest_array):
def apply_spans_last(spans, src_array, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=src_array.dtype) if dest_array is None else dest_array
spans = spans[1:]-1
dest_array[:] = src_array[spans]
return dest_array


@exetera_njit
def apply_spans_max(spans, src_array, dest_array):

for i in range(len(spans)-1):
def apply_spans_max(spans, src_array, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=src_array.dtype) if dest_array is None else dest_array
for i in range(len(spans) - 1):
cur = spans[i]
next = spans[i+1]
next = spans[i + 1]
if next - cur == 1:
dest_array[i] = src_array[cur]
else:
# dest_array[i] = src_array[cur:next].max() # doesn't work for fixed strings in Python?
max_val=src_array[cur]
for idx in range(cur+1,next):
if src_array[idx]>max_val:
max_val=src_array[idx]

dest_array[i]=max_val
max_val = src_array[cur]
for idx in range(cur + 1, next):
if src_array[idx] > max_val:
max_val = src_array[idx]

dest_array[i] = max_val
return dest_array

@exetera_njit
def apply_spans_min(spans, src_array, dest_array):

for i in range(len(spans)-1):
@exetera_njit
def apply_spans_min(spans, src_array, dest_array=None):
dest_array = np.zeros(len(spans) - 1, dtype=src_array.dtype) if dest_array is None else dest_array
for i in range(len(spans) - 1):
cur = spans[i]
next = spans[i+1]
next = spans[i + 1]
if next - cur == 1:
dest_array[i] = src_array[cur]
else:
# dest_array[i] = src_array[cur:next].min() # doesn't work for fixed strings in Python?
min_val=src_array[cur]
for idx in range(cur+1,next):
if src_array[idx]<min_val:
min_val=src_array[idx]

dest_array[i]=min_val
min_val = src_array[cur]
for idx in range(cur + 1, next):
if src_array[idx] < min_val:
min_val = src_array[idx]

dest_array[i] = min_val
return dest_array



# def _apply_spans_concat(spans, src_field):
Expand Down
37 changes: 14 additions & 23 deletions exetera/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from exetera.core import operations as ops
from exetera.core import dataset as ds
from exetera.core import dataframe as df
from exetera.core import utils


class Session(AbstractSession):
Expand Down Expand Up @@ -425,18 +426,13 @@ def _apply_spans_no_src(self,
:param dest: if set, the field to which the results are written
:returns: A numpy array containing the resulting values
"""
assert (dest is None or isinstance(dest, Field))
assert (dest is None or isinstance(dest, Field)) # dest is None or field

if dest is not None:
dest_f = val.field_from_parameter(self, 'dest', dest)
results = np.zeros(len(spans) - 1, dtype=dest_f.data.dtype)
predicate(spans, results)
dest_f.data.write(results)
return results
else:
results = np.zeros(len(spans) - 1, dtype='int64')
predicate(spans, results)
return results
results = predicate(spans)
if dest is not None: # dest is a field
assert (results.dtype.type == dest.data.dtype.type, 'The field dtype does not match with the data type.')
dest.data.write(results)
return results

def _apply_spans_src(self,
predicate: Callable[[np.ndarray, np.ndarray, np.ndarray], None],
Expand All @@ -453,24 +449,19 @@ def _apply_spans_src(self,
:param dest: if set, the field to which the results are written
:returns: A numpy array containing the resulting values
"""
assert (dest is None or isinstance(dest, Field))
assert (dest is None or isinstance(dest, Field)) # dest is None or a field
target_ = val.array_from_parameter(self, 'target', target)
if len(target) != spans[-1]:
error_msg = ("'target' (length {}) must be one element shorter than 'spans' "
"(length {})")
raise ValueError(error_msg.format(len(target_), len(spans)))

if dest is not None:
dest_f = val.field_from_parameter(self, 'dest', dest)
results = np.zeros(len(spans) - 1, dtype=dest_f.data.dtype)
predicate(spans, target_, results)
dest_f.data.write(results)
return results
else:
data_type = 'int32' if len(spans) < 2000000000 else 'int64'
results = np.zeros(len(spans) - 1, dtype=data_type)
predicate(spans, target_, results)
return results
results = predicate(spans, target_)

if dest is not None: # dest is a field
assert (results.dtype.type == dest.data.dtype.type, 'The field dtype does not match with the data type.')
dest.data.write(results)
return results

def apply_spans_index_of_min(self,
spans: np.ndarray,
Expand Down
4 changes: 3 additions & 1 deletion exetera/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

SECONDS_PER_DAY = 86400
PERMITTED_NUMERIC_TYPES = ('float32', 'float64', 'bool', 'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'int64')
INT64_INDEX_LENGTH = 2**31-1


# environment variable used to toggle Numba off for testing
USE_NUMBA_VAR = "USE_NUMBA"
Expand Down Expand Up @@ -435,4 +437,4 @@ def guess_encoding(filename):
return "utf-8-sig"
else:
return "utf-8"

25 changes: 25 additions & 0 deletions tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from exetera.core import operations as ops
from exetera.core import utils

from .utils import slow_test


class TestOpsUtils(unittest.TestCase):

Expand Down Expand Up @@ -1265,6 +1267,29 @@ def test_get_spans_two_field(self):
spans3= ops._get_spans_for_2_fields_by_spans(spans1,spans2)
self.assertTrue(list(spans), list(spans3))

@slow_test
def test_get_spans_two_field(self):
data1 = np.zeros(utils.INT64_INDEX_LENGTH + 1, 'int8')
data2 = np.zeros(utils.INT64_INDEX_LENGTH + 1, 'int8')
spans = ops._get_spans_for_2_fields(data1, data2)
self.assertEqual(spans.dtype, 'int64')

def test_get_spans_for_multi_fields(self):
data1 = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5])
data2 = np.array([1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 7, 8, 9, 10])
data3 = np.array([1, 2, 1, 2, 2, 1, 2, 3, 3, 1, 2, 3, 4, 10])
spans = ops._get_spans_for_multi_fields(np.asarray([data1, data2, data3]))
self.assertListEqual(spans.tolist(), [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
self.assertEqual(spans.dtype, 'int32')

@slow_test
def test_get_spans_for_multi_fields_int64(self):
data1 = np.zeros(utils.INT64_INDEX_LENGTH+1, 'int8')
data2 = np.zeros(utils.INT64_INDEX_LENGTH + 1, 'int8')
data3 = np.zeros(utils.INT64_INDEX_LENGTH + 1, 'int8')
spans = ops._get_spans_for_multi_fields(np.asarray([data1, data2, data3]))
self.assertEqual(spans.dtype, 'int64')


class TestCheckIfSorted(unittest.TestCase):

Expand Down
Loading

0 comments on commit b9a9d64

Please sign in to comment.