Skip to content

Commit

Permalink
Allow partial application for tables, disable unstable-fn for primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Oct 24, 2024
1 parent 586512f commit 2185096
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 25 deletions.
4 changes: 0 additions & 4 deletions benches/example_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ pub fn criterion_benchmark(c: &mut Criterion) {
if path_string.contains("python_array_optimize") {
continue;
}
// skip unstable_fn because partial application is banned
if path_string.contains("unstable-fn") {
continue;
}

let name = path.file_stem().unwrap().to_string_lossy().to_string();
let filename = path.to_string_lossy().to_string();
Expand Down
19 changes: 15 additions & 4 deletions src/sort/fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,24 @@ impl PrimitiveLike for Ctor {
})
}

fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option<Value> {
fn apply(&self, values: &[Value], egraph: Option<&mut EGraph>) -> Option<Value> {
let egraph = egraph.expect("`unstable-fn` is not supported yet in facts.");
let name = Symbol::load(&StringSort, &values[0]);

// TODO: solve static partial application
assert_eq!(values.len(), 1, "partial application banned");
let schema = if let Some(f) = egraph.functions.get(&name) {
&f.schema
} else {
panic!("`unstable-fn` only supports tables, found {name}")
};

assert!(values[1..].len() <= schema.input.len());
let args: Vec<(ArcSort, Value)> = values[1..]
.iter()
.zip(&schema.input)
.map(|(value, sort)| (sort.clone(), *value))
.collect();

ValueFunction(name, Vec::new()).store(&self.function)
ValueFunction(name, args).store(&self.function)
}
}

Expand Down
19 changes: 2 additions & 17 deletions tests/unstable-fn.egg
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@

(sort MathFn (UnstableFn (Math) Math))


(function square (Math) Math)
(rewrite (square x) (Mul x x))

(let square-fn (unstable-fn "square" ))

;; test that we can call a function
;; Test that we can call a function
(let squared-3 (unstable-app square-fn (Num 3)))
(check (= squared-3 (square (Num 3))))

;; test that we can apply a function to a list

;; Test that we can apply a function to a list
(function list-map-math (MathList MathFn) MathList)
(rewrite (list-map-math (Nil) fn) (Nil))
(rewrite (list-map-math (Cons x xs) fn) (Cons (unstable-app fn x) (list-map-math xs fn)))
Expand All @@ -34,7 +32,6 @@
(check (= squared-x (Cons (Num 1) (Cons (Num 4) (Cons (Num 9) (Nil))))))

;; Test that we can partially apply a function in a rewrite rule

(function list-multiply-by (MathList Math) MathList)
(rewrite (list-multiply-by l i) (list-map-math l (unstable-fn "Mul" i)))

Expand All @@ -43,7 +40,6 @@
(check (= doubled-x (Cons (Num 2) (Cons (Num 4) (Cons (Num 6) (Nil))))))

;; Test we can define a higher order compose function

(function composed-math (MathFn MathFn Math) Math)
(rewrite (composed-math f g v) (unstable-app f (unstable-app g v)))

Expand All @@ -53,17 +49,6 @@
(run-schedule (saturate (run)))
(check (= squared-doubled-x (Cons (Num 4) (Cons (Num 16) (Cons (Num 36) (Nil))))))


;; See that it supports primitive values as well
(sort i64Fun (UnstableFn (i64) i64))

(function composed-i64-math (MathFn i64Fun i64) Math)
(rewrite (composed-i64-math f g v) (unstable-app f (Num (unstable-app g v))))

(let res (composed-i64-math square-fn (unstable-fn "*" 2) 4))
(run-schedule (saturate (run)))
(check (= res (Num 64)))

;; Verify that function parsing works with a function with no args
(sort TestNullaryFunction (UnstableFn () Math))
;; Verify that we know the type of a function based on the string name
Expand Down

0 comments on commit 2185096

Please sign in to comment.